From ca5bae1e3b482266aaf221c13e4911d095b60c7c Mon Sep 17 00:00:00 2001 From: Maxim Katcharov Date: Mon, 5 Jun 2023 09:52:03 -0600 Subject: [PATCH 1/6] Implement OIDC SASL mechanism in sync (#1107) JAVA-4980 --- .evergreen/prepare-oidc-get-tokens-docker.sh | 50 + .evergreen/prepare-oidc-server-docker.sh | 50 + .../src/test/unit/util/ThreadTestHelpers.java | 12 +- .../com/mongodb/AuthenticationMechanism.java | 7 + .../main/com/mongodb/ConnectionString.java | 5 + .../src/main/com/mongodb/MongoCredential.java | 224 ++++- .../src/main/com/mongodb/internal/Locks.java | 16 + .../internal/connection/Authenticator.java | 11 + .../internal/connection/AwsAuthenticator.java | 54 +- .../connection/InternalStreamConnection.java | 67 +- .../InternalStreamConnectionFactory.java | 21 +- .../InternalStreamConnectionInitializer.java | 20 +- .../connection/MongoCredentialWithCache.java | 28 +- .../connection/OidcAuthenticator.java | 678 +++++++++++++ .../connection/SaslAuthenticator.java | 72 +- .../connection/ScramShaAuthenticator.java | 48 +- .../com/mongodb/client/TestHelper.java | 47 + .../connection/TestCommandListener.java | 45 +- .../auth/{ => legacy}/connection-string.json | 143 ++- .../auth/reauthenticate_with_retry.json | 191 ++++ .../auth/reauthenticate_without_retry.json | 191 ++++ .../com/mongodb/AuthConnectionStringTest.java | 54 +- .../com/mongodb/client/unified/Entities.java | 6 +- .../client/unified/UnifiedAuthTest.java | 39 + .../OidcAuthenticationProseTests.java | 937 ++++++++++++++++++ 25 files changed, 2876 insertions(+), 140 deletions(-) create mode 100755 .evergreen/prepare-oidc-get-tokens-docker.sh create mode 100755 .evergreen/prepare-oidc-server-docker.sh create mode 100644 driver-core/src/main/com/mongodb/internal/connection/OidcAuthenticator.java create mode 100644 driver-core/src/test/functional/com/mongodb/client/TestHelper.java rename driver-core/src/test/resources/auth/{ => legacy}/connection-string.json (74%) create mode 100644 driver-core/src/test/resources/unified-test-format/auth/reauthenticate_with_retry.json create mode 100644 driver-core/src/test/resources/unified-test-format/auth/reauthenticate_without_retry.json create mode 100644 driver-sync/src/test/functional/com/mongodb/client/unified/UnifiedAuthTest.java create mode 100644 driver-sync/src/test/functional/com/mongodb/internal/connection/OidcAuthenticationProseTests.java diff --git a/.evergreen/prepare-oidc-get-tokens-docker.sh b/.evergreen/prepare-oidc-get-tokens-docker.sh new file mode 100755 index 00000000000..e904d5d2b89 --- /dev/null +++ b/.evergreen/prepare-oidc-get-tokens-docker.sh @@ -0,0 +1,50 @@ +#!/bin/bash + +set -o xtrace +set -o errexit # Exit the script with error if any of the commands fail + +############################################ +# Main Program # +############################################ + +# Supported/used environment variables: +# DRIVERS_TOOLS The path to evergreeen tools +# OIDC_AWS_* Required OIDC_AWS_* env variables must be configured +# +# Environment variables used as output: +# OIDC_TESTS_ENABLED Allows running OIDC tests +# OIDC_TOKEN_DIR The path to generated OIDC AWS tokens +# AWS_WEB_IDENTITY_TOKEN_FILE The path to AWS token for device workflow + +if [ -z ${DRIVERS_TOOLS+x} ]; then + echo "DRIVERS_TOOLS. is not set"; + exit 1 +fi + +if [ -z ${OIDC_AWS_ROLE_ARN+x} ]; then + echo "OIDC_AWS_ROLE_ARN. is not set"; + exit 1 +fi + +if [ -z ${OIDC_AWS_SECRET_ACCESS_KEY+x} ]; then + echo "OIDC_AWS_SECRET_ACCESS_KEY. is not set"; + exit 1 +fi + +if [ -z ${OIDC_AWS_ACCESS_KEY_ID+x} ]; then + echo "OIDC_AWS_ACCESS_KEY_ID. is not set"; + exit 1 +fi + +export AWS_ROLE_ARN=${OIDC_AWS_ROLE_ARN} +export AWS_SECRET_ACCESS_KEY=${OIDC_AWS_SECRET_ACCESS_KEY} +export AWS_ACCESS_KEY_ID=${OIDC_AWS_ACCESS_KEY_ID} +export OIDC_FOLDER=${DRIVERS_TOOLS}/.evergreen/auth_oidc +export OIDC_TOKEN_DIR=${OIDC_FOLDER}/test_tokens +export AWS_WEB_IDENTITY_TOKEN_FILE=${OIDC_TOKEN_DIR}/test1 +export OIDC_TESTS_ENABLED=true + +echo "Configuring OIDC server for local authentication tests" + +cd ${OIDC_FOLDER} +DRIVERS_TOOLS=${DRIVERS_TOOLS} ./oidc_get_tokens.sh \ No newline at end of file diff --git a/.evergreen/prepare-oidc-server-docker.sh b/.evergreen/prepare-oidc-server-docker.sh new file mode 100755 index 00000000000..0fcd1ed4194 --- /dev/null +++ b/.evergreen/prepare-oidc-server-docker.sh @@ -0,0 +1,50 @@ +#!/bin/bash + +set -o xtrace +set -o errexit # Exit the script with error if any of the commands fail + +############################################ +# Main Program # +############################################ + +# Supported/used environment variables: +# DRIVERS_TOOLS The path to evergreeen tools +# OIDC_AWS_* OIDC_AWS_* env variables must be configured +# +# Environment variables used as output: +# OIDC_TESTS_ENABLED Allows running OIDC tests +# OIDC_TOKEN_DIR The path to generated tokens +# AWS_WEB_IDENTITY_TOKEN_FILE The path to AWS token for device workflow + +if [ -z ${DRIVERS_TOOLS+x} ]; then + echo "DRIVERS_TOOLS. is not set"; + exit 1 +fi + +if [ -z ${OIDC_AWS_ROLE_ARN+x} ]; then + echo "OIDC_AWS_ROLE_ARN. is not set"; + exit 1 +fi + +if [ -z ${OIDC_AWS_SECRET_ACCESS_KEY+x} ]; then + echo "OIDC_AWS_SECRET_ACCESS_KEY. is not set"; + exit 1 +fi + +if [ -z ${OIDC_AWS_ACCESS_KEY_ID+x} ]; then + echo "OIDC_AWS_ACCESS_KEY_ID. is not set"; + exit 1 +fi + +export AWS_ROLE_ARN=${OIDC_AWS_ROLE_ARN} +export AWS_SECRET_ACCESS_KEY=${OIDC_AWS_SECRET_ACCESS_KEY} +export AWS_ACCESS_KEY_ID=${OIDC_AWS_ACCESS_KEY_ID} +export OIDC_FOLDER=${DRIVERS_TOOLS}/.evergreen/auth_oidc +export OIDC_TOKEN_DIR=${OIDC_FOLDER}/test_tokens +export AWS_WEB_IDENTITY_TOKEN_FILE=${OIDC_TOKEN_DIR}/test1 +export OIDC_TESTS_ENABLED=true + +echo "Configuring OIDC server for local authentication tests" + +cd ${OIDC_FOLDER} +DRIVERS_TOOLS=${DRIVERS_TOOLS} ./start_local_server.sh \ No newline at end of file diff --git a/bson/src/test/unit/util/ThreadTestHelpers.java b/bson/src/test/unit/util/ThreadTestHelpers.java index a4767c503f9..e2115da079f 100644 --- a/bson/src/test/unit/util/ThreadTestHelpers.java +++ b/bson/src/test/unit/util/ThreadTestHelpers.java @@ -31,15 +31,19 @@ private ThreadTestHelpers() { } public static void executeAll(final int nThreads, final Runnable c) { + executeAll(Collections.nCopies(nThreads, c).toArray(new Runnable[0])); + } + + public static void executeAll(final Runnable... runnables) { ExecutorService service = null; try { - service = Executors.newFixedThreadPool(nThreads); - CountDownLatch latch = new CountDownLatch(nThreads); + service = Executors.newFixedThreadPool(runnables.length); + CountDownLatch latch = new CountDownLatch(runnables.length); List failures = Collections.synchronizedList(new ArrayList<>()); - for (int i = 0; i < nThreads; i++) { + for (final Runnable runnable : runnables) { service.submit(() -> { try { - c.run(); + runnable.run(); } catch (Throwable e) { failures.add(e); } finally { diff --git a/driver-core/src/main/com/mongodb/AuthenticationMechanism.java b/driver-core/src/main/com/mongodb/AuthenticationMechanism.java index db8a909b79d..7a7b7415ef6 100644 --- a/driver-core/src/main/com/mongodb/AuthenticationMechanism.java +++ b/driver-core/src/main/com/mongodb/AuthenticationMechanism.java @@ -37,6 +37,13 @@ public enum AuthenticationMechanism { */ MONGODB_AWS("MONGODB-AWS"), + /** + * The MONGODB-OIDC mechanism. + * @since 4.10 + * @mongodb.server.release 7.0 + */ + MONGODB_OIDC("MONGODB-OIDC"), + /** * The MongoDB X.509 mechanism. This mechanism is available only with client certificates over SSL. */ diff --git a/driver-core/src/main/com/mongodb/ConnectionString.java b/driver-core/src/main/com/mongodb/ConnectionString.java index e715b8983f6..c5197b8b7d4 100644 --- a/driver-core/src/main/com/mongodb/ConnectionString.java +++ b/driver-core/src/main/com/mongodb/ConnectionString.java @@ -48,6 +48,7 @@ import java.util.Set; import java.util.concurrent.TimeUnit; +import static com.mongodb.internal.connection.OidcAuthenticator.OidcValidator.validateCreateOidcCredential; import static java.lang.String.format; import static java.util.Arrays.asList; import static java.util.Collections.singletonList; @@ -975,6 +976,10 @@ private MongoCredential createMongoCredentialWithMechanism(final AuthenticationM case MONGODB_AWS: credential = MongoCredential.createAwsCredential(userName, password); break; + case MONGODB_OIDC: + validateCreateOidcCredential(password); + credential = MongoCredential.createOidcCredential(userName); + break; default: throw new UnsupportedOperationException(format("The connection string contains an invalid authentication mechanism'. " + "'%s' is not a supported authentication mechanism", diff --git a/driver-core/src/main/com/mongodb/MongoCredential.java b/driver-core/src/main/com/mongodb/MongoCredential.java index ffa2a3c4e02..418863dc21c 100644 --- a/driver-core/src/main/com/mongodb/MongoCredential.java +++ b/driver-core/src/main/com/mongodb/MongoCredential.java @@ -17,22 +17,27 @@ package com.mongodb; import com.mongodb.annotations.Beta; +import com.mongodb.annotations.Evolving; import com.mongodb.annotations.Immutable; import com.mongodb.lang.Nullable; +import java.time.Duration; import java.util.Arrays; import java.util.Collections; import java.util.HashMap; +import java.util.List; import java.util.Map; import java.util.Objects; import static com.mongodb.AuthenticationMechanism.GSSAPI; import static com.mongodb.AuthenticationMechanism.MONGODB_AWS; +import static com.mongodb.AuthenticationMechanism.MONGODB_OIDC; import static com.mongodb.AuthenticationMechanism.MONGODB_X509; import static com.mongodb.AuthenticationMechanism.PLAIN; import static com.mongodb.AuthenticationMechanism.SCRAM_SHA_1; import static com.mongodb.AuthenticationMechanism.SCRAM_SHA_256; import static com.mongodb.assertions.Assertions.notNull; +import static com.mongodb.internal.connection.OidcAuthenticator.OidcValidator.validateOidcCredentialConstruction; /** * Represents credentials to authenticate to a mongo server,as well as the source of the credentials and the authentication mechanism to @@ -179,6 +184,70 @@ public final class MongoCredential { @Beta(Beta.Reason.CLIENT) public static final String AWS_CREDENTIAL_PROVIDER_KEY = "AWS_CREDENTIAL_PROVIDER"; + /** + * The provider name. The value must be a string. + *

+ * If this is provided, + * {@link MongoCredential#REQUEST_TOKEN_CALLBACK_KEY} and + * {@link MongoCredential#REFRESH_TOKEN_CALLBACK_KEY} + * must not be provided. + * + * @see #createOidcCredential(String) + * @since 4.10 + */ + public static final String PROVIDER_NAME_KEY = "PROVIDER_NAME"; + + /** + * This callback is invoked when the OIDC-based authenticator requests + * tokens from the identity provider. The type of the value must be + * {@link OidcRequestCallback}. + *

+ * If this is provided, {@link MongoCredential#PROVIDER_NAME_KEY} + * must not be provided. + * + * @see #createOidcCredential(String) + * @since 4.10 + */ + public static final String REQUEST_TOKEN_CALLBACK_KEY = "REQUEST_TOKEN_CALLBACK"; + + /** + * Mechanism key for invoked when the OIDC-based authenticator refreshes + * tokens from the identity provider. If this callback is not provided, + * then refresh operations will not be attempted.The type of the value + * must be {@link OidcRefreshCallback}. + *

+ * If this is provided, {@link MongoCredential#PROVIDER_NAME_KEY} + * must not be provided. + * + * @see #createOidcCredential(String) + * @since 4.10 + */ + public static final String REFRESH_TOKEN_CALLBACK_KEY = "REFRESH_TOKEN_CALLBACK"; + + /** + * Mechanism key for a list of allowed hostnames or ip-addresses for MongoDB connections. Ports must be excluded. + * The hostnames may include a leading "*." wildcard, which allows for matching (potentially nested) subdomains. + * When MONGODB-OIDC authentication is attempted against a hostname that does not match any of list of allowed hosts + * the driver will raise an error. The type of the value must be {@code List}. + * + * @see MongoCredential#DEFAULT_ALLOWED_HOSTS + * @see #createOidcCredential(String) + * @since 4.10 + */ + public static final String ALLOWED_HOSTS_KEY = "ALLOWED_HOSTS"; + + /** + * The list of allowed hosts that will be used if no + * {@link MongoCredential#ALLOWED_HOSTS_KEY} value is supplied. + * The default allowed hosts are: + * {@code "*.mongodb.net", "*.mongodb-dev.net", "*.mongodbgov.net", "localhost", "127.0.0.1", "::1"} + * + * @see #createOidcCredential(String) + * @since 4.10 + */ + public static final List DEFAULT_ALLOWED_HOSTS = Collections.unmodifiableList(Arrays.asList( + "*.mongodb.net", "*.mongodb-dev.net", "*.mongodbgov.net", "localhost", "127.0.0.1", "::1")); + /** * Creates a MongoCredential instance with an unspecified mechanism. The client will negotiate the best mechanism based on the * version of the server that the client is authenticating to. @@ -327,6 +396,23 @@ public static MongoCredential createAwsCredential(@Nullable final String userNam return new MongoCredential(MONGODB_AWS, userName, "$external", password); } + /** + * Creates a MongoCredential instance for the MONGODB-OIDC mechanism. + * + * @param userName the user name, which may be null. This is the OIDC principal name. + * @return the credential + * @since 4.10 + * @see #withMechanismProperty(String, Object) + * @see #PROVIDER_NAME_KEY + * @see #REQUEST_TOKEN_CALLBACK_KEY + * @see #REFRESH_TOKEN_CALLBACK_KEY + * @see #ALLOWED_HOSTS_KEY + * @mongodb.server.release 7.0 + */ + public static MongoCredential createOidcCredential(@Nullable final String userName) { + return new MongoCredential(MONGODB_OIDC, userName, "$external", null); + } + /** * Creates a new MongoCredential as a copy of this instance, with the specified mechanism property added. * @@ -370,7 +456,11 @@ public MongoCredential withMechanism(final AuthenticationMechanism mechanism) { MongoCredential(@Nullable final AuthenticationMechanism mechanism, @Nullable final String userName, final String source, @Nullable final char[] password, final Map mechanismProperties) { - if (userName == null && !Arrays.asList(MONGODB_X509, MONGODB_AWS).contains(mechanism)) { + if (mechanism == MONGODB_OIDC) { + validateOidcCredentialConstruction(source, mechanismProperties); + } + + if (userName == null && !Arrays.asList(MONGODB_X509, MONGODB_AWS, MONGODB_OIDC).contains(mechanism)) { throw new IllegalArgumentException("username can not be null"); } @@ -543,4 +633,136 @@ public String toString() { + ", mechanismProperties=" + '}'; } + + /** + * The context for the {@link OidcRequestCallback#onRequest(OidcRequestContext) OIDC request callback}. + */ + @Evolving + public interface OidcRequestContext { + /** + * @return The OIDC Identity Provider's configuration that can be used to acquire an Access Token. + */ + IdpInfo getIdpInfo(); + + /** + * @return The timeout that this callback must complete within. + */ + Duration getTimeout(); + } + + /** + * The context for the {@link OidcRefreshCallback#onRefresh(OidcRefreshContext) OIDC refresh callback}. + */ + @Evolving + public interface OidcRefreshContext extends OidcRequestContext { + /** + * @return The OIDC Refresh token supplied by a prior callback invocation. + */ + String getRefreshToken(); + } + + /** + * This callback is invoked when the OIDC-based authenticator requests + * tokens from the identity provider. + *

+ * It does not have to be thread-safe, unless it is provided to multiple + * MongoClients. + */ + public interface OidcRequestCallback { + /** + * @param context The context. + * @return The response produced by an OIDC Identity Provider + */ + IdpResponse onRequest(OidcRequestContext context); + } + + /** + * This callback is invoked when the OIDC-based authenticator refreshes + * tokens from the identity provider. If this callback is not provided, + * then refresh operations will not be attempted. + *

+ * It does not have to be thread-safe, unless it is provided to multiple + * MongoClients. + */ + public interface OidcRefreshCallback { + /** + * @param context The context. + * @return The response produced by an OIDC Identity Provider + */ + IdpResponse onRefresh(OidcRefreshContext context); + } + + /** + * The OIDC Identity Provider's configuration that can be used to acquire an Access Token. + */ + @Evolving + public interface IdpInfo { + /** + * @return URL which describes the Authorization Server. This identifier is the + * iss of provided access tokens, and is viable for RFC8414 metadata + * discovery and RFC9207 identification. + */ + String getIssuer(); + + /** + * @return Unique client ID for this OIDC client. + */ + String getClientId(); + + /** + * @return Additional scopes to request from Identity Provider. Immutable. + */ + List getRequestScopes(); + } + + /** + * The response produced by an OIDC Identity Provider. + */ + public static final class IdpResponse { + + private final String accessToken; + + @Nullable + private final Integer accessTokenExpiresInSeconds; + + @Nullable + private final String refreshToken; + + /** + * @param accessToken The OIDC access token + * @param accessTokenExpiresInSeconds The expiration in seconds. If null, the access token is single-use. + * @param refreshToken The refresh token. If null, refresh will not be attempted. + */ + public IdpResponse(final String accessToken, @Nullable final Integer accessTokenExpiresInSeconds, + @Nullable final String refreshToken) { + notNull("accessToken", accessToken); + this.accessToken = accessToken; + this.accessTokenExpiresInSeconds = accessTokenExpiresInSeconds; + this.refreshToken = refreshToken; + } + + /** + * @return The OIDC access token. + */ + public String getAccessToken() { + return accessToken; + } + + /** + * @return The expiration time for the access token in seconds. + * If null, the access token is single-use. + */ + @Nullable + public Integer getAccessTokenExpiresInSeconds() { + return accessTokenExpiresInSeconds; + } + + /** + * @return The OIDC refresh token. If null, refresh will not be attempted. + */ + @Nullable + public String getRefreshToken() { + return refreshToken; + } + } } diff --git a/driver-core/src/main/com/mongodb/internal/Locks.java b/driver-core/src/main/com/mongodb/internal/Locks.java index 51583bfd56f..f727caf20f0 100644 --- a/driver-core/src/main/com/mongodb/internal/Locks.java +++ b/driver-core/src/main/com/mongodb/internal/Locks.java @@ -20,6 +20,7 @@ import java.util.concurrent.locks.Lock; import java.util.concurrent.locks.ReentrantLock; +import java.util.concurrent.locks.StampedLock; import java.util.function.Supplier; import static com.mongodb.internal.thread.InterruptionUtil.interruptAndCreateMongoInterruptedException; @@ -35,6 +36,21 @@ public static void withLock(final Lock lock, final Runnable action) { }); } + public static V withLock(final StampedLock lock, final Supplier supplier) { + long stamp; + try { + stamp = lock.writeLockInterruptibly(); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + throw new MongoInterruptedException("Interrupted waiting for lock", e); + } + try { + return supplier.get(); + } finally { + lock.unlockWrite(stamp); + } + } + public static V withLock(final Lock lock, final Supplier supplier) { return checkedWithLock(lock, supplier::get); } diff --git a/driver-core/src/main/com/mongodb/internal/connection/Authenticator.java b/driver-core/src/main/com/mongodb/internal/connection/Authenticator.java index 9ec4780d958..45e0b078452 100644 --- a/driver-core/src/main/com/mongodb/internal/connection/Authenticator.java +++ b/driver-core/src/main/com/mongodb/internal/connection/Authenticator.java @@ -21,6 +21,7 @@ import com.mongodb.ServerApi; import com.mongodb.connection.ClusterConnectionMode; import com.mongodb.connection.ConnectionDescription; +import com.mongodb.connection.ServerType; import com.mongodb.internal.async.SingleResultCallback; import com.mongodb.lang.NonNull; import com.mongodb.lang.Nullable; @@ -42,6 +43,11 @@ public abstract class Authenticator { this.serverApi = serverApi; } + public static boolean shouldAuthenticate(@Nullable final Authenticator authenticator, + final ConnectionDescription connectionDescription) { + return authenticator != null && connectionDescription.getServerType() != ServerType.REPLICA_SET_ARBITER; + } + @NonNull MongoCredentialWithCache getMongoCredentialWithCache() { return credential; @@ -93,4 +99,9 @@ T getNonNullMechanismProperty(final String key, @Nullable final T defaultVal abstract void authenticateAsync(InternalConnection connection, ConnectionDescription connectionDescription, SingleResultCallback callback); + + public void reauthenticate(final InternalConnection connection) { + authenticate(connection, connection.getDescription()); + } + } diff --git a/driver-core/src/main/com/mongodb/internal/connection/AwsAuthenticator.java b/driver-core/src/main/com/mongodb/internal/connection/AwsAuthenticator.java index ec0fc3f9c8f..35f9f8120ee 100644 --- a/driver-core/src/main/com/mongodb/internal/connection/AwsAuthenticator.java +++ b/driver-core/src/main/com/mongodb/internal/connection/AwsAuthenticator.java @@ -16,7 +16,6 @@ package com.mongodb.internal.connection; -import com.mongodb.AuthenticationMechanism; import com.mongodb.AwsCredential; import com.mongodb.MongoClientException; import com.mongodb.MongoCredential; @@ -27,14 +26,10 @@ import com.mongodb.internal.authentication.AwsCredentialHelper; import com.mongodb.lang.Nullable; import org.bson.BsonBinary; -import org.bson.BsonBinaryWriter; import org.bson.BsonDocument; import org.bson.BsonInt32; import org.bson.BsonString; import org.bson.RawBsonDocument; -import org.bson.codecs.BsonDocumentCodec; -import org.bson.codecs.EncoderContext; -import org.bson.io.BasicOutputBuffer; import javax.security.sasl.SaslClient; import javax.security.sasl.SaslException; @@ -77,27 +72,12 @@ protected SaslClient createSaslClient(final ServerAddress serverAddress) { return new AwsSaslClient(getMongoCredential()); } - private static class AwsSaslClient implements SaslClient { - private final MongoCredential credential; + private static class AwsSaslClient extends SaslClientImpl { private final byte[] clientNonce = new byte[RANDOM_LENGTH]; private int step = -1; AwsSaslClient(final MongoCredential credential) { - this.credential = credential; - } - - @Override - public String getMechanismName() { - AuthenticationMechanism authMechanism = credential.getAuthenticationMechanism(); - if (authMechanism == null) { - throw new IllegalArgumentException("Authentication mechanism cannot be null"); - } - return authMechanism.getMechanismName(); - } - - @Override - public boolean hasInitialResponse() { - return true; + super(credential); } @Override @@ -117,26 +97,6 @@ public boolean isComplete() { return step == 1; } - @Override - public byte[] unwrap(final byte[] bytes, final int i, final int i1) { - throw new UnsupportedOperationException("Not implemented yet!"); - } - - @Override - public byte[] wrap(final byte[] bytes, final int i, final int i1) { - throw new UnsupportedOperationException("Not implemented yet!"); - } - - @Override - public Object getNegotiatedProperty(final String s) { - throw new UnsupportedOperationException("Not implemented yet!"); - } - - @Override - public void dispose() { - // nothing to do - } - private byte[] computeClientFirstMessage() { new SecureRandom().nextBytes(this.clientNonce); @@ -184,6 +144,7 @@ private byte[] computeClientFinalMessage(final byte[] serverFirst) throws SaslEx private AwsCredential createAwsCredential() { AwsCredential awsCredential; + MongoCredential credential = getCredential(); if (credential.getUserName() != null) { if (credential.getPassword() == null) { throw new MongoClientException("secretAccessKey is required for AWS credential"); @@ -207,13 +168,4 @@ private AwsCredential createAwsCredential() { return awsCredential; } } - - private static byte[] toBson(final BsonDocument document) { - byte[] bytes; - BasicOutputBuffer buffer = new BasicOutputBuffer(); - new BsonDocumentCodec().encode(new BsonBinaryWriter(buffer), document, EncoderContext.builder().build()); - bytes = new byte[buffer.size()]; - System.arraycopy(buffer.getInternalBuffer(), 0, bytes, 0, buffer.getSize()); - return bytes; - } } diff --git a/driver-core/src/main/com/mongodb/internal/connection/InternalStreamConnection.java b/driver-core/src/main/com/mongodb/internal/connection/InternalStreamConnection.java index dec5a1d1977..143d66f2096 100644 --- a/driver-core/src/main/com/mongodb/internal/connection/InternalStreamConnection.java +++ b/driver-core/src/main/com/mongodb/internal/connection/InternalStreamConnection.java @@ -18,6 +18,7 @@ import com.mongodb.LoggerSettings; import com.mongodb.MongoClientException; +import com.mongodb.MongoCommandException; import com.mongodb.MongoCompressor; import com.mongodb.MongoException; import com.mongodb.MongoInternalException; @@ -64,6 +65,7 @@ import java.util.Optional; import java.util.Set; import java.util.concurrent.atomic.AtomicBoolean; +import java.util.function.Supplier; import static com.mongodb.assertions.Assertions.assertNotNull; import static com.mongodb.assertions.Assertions.isTrue; @@ -92,6 +94,19 @@ @NotThreadSafe public class InternalStreamConnection implements InternalConnection { + private static volatile boolean recordEverything = false; + + /** + * Will attempt to record events to the command listener that are usually + * suppressed. + * + * @param recordEverything whether to attempt to record everything + */ + @VisibleForTesting(otherwise = VisibleForTesting.AccessModifier.PRIVATE) + public static void setRecordEverything(final boolean recordEverything) { + InternalStreamConnection.recordEverything = recordEverything; + } + private static final Set SECURITY_SENSITIVE_COMMANDS = new HashSet<>(asList( "authenticate", "saslStart", @@ -111,6 +126,8 @@ public class InternalStreamConnection implements InternalConnection { private static final Logger LOGGER = Loggers.getLogger("connection"); private final ClusterConnectionMode clusterConnectionMode; + @Nullable + private final Authenticator authenticator; private final boolean isMonitoringConnection; private final ServerId serverId; private final ConnectionGenerationSupplier connectionGenerationSupplier; @@ -122,6 +139,7 @@ public class InternalStreamConnection implements InternalConnection { private final AtomicBoolean isClosed = new AtomicBoolean(); private final AtomicBoolean opened = new AtomicBoolean(); + private final AtomicBoolean authenticated = new AtomicBoolean(); private final List compressorList; private final LoggerSettings loggerSettings; @@ -147,17 +165,20 @@ public InternalStreamConnection(final ClusterConnectionMode clusterConnectionMod final ConnectionGenerationSupplier connectionGenerationSupplier, final StreamFactory streamFactory, final List compressorList, final CommandListener commandListener, final InternalConnectionInitializer connectionInitializer) { - this(clusterConnectionMode, false, serverId, connectionGenerationSupplier, streamFactory, compressorList, + this(clusterConnectionMode, null, false, serverId, connectionGenerationSupplier, streamFactory, compressorList, LoggerSettings.builder().build(), commandListener, connectionInitializer); } - public InternalStreamConnection(final ClusterConnectionMode clusterConnectionMode, final boolean isMonitoringConnection, + public InternalStreamConnection(final ClusterConnectionMode clusterConnectionMode, + @Nullable final Authenticator authenticator, + final boolean isMonitoringConnection, final ServerId serverId, final ConnectionGenerationSupplier connectionGenerationSupplier, final StreamFactory streamFactory, final List compressorList, final LoggerSettings loggerSettings, final CommandListener commandListener, final InternalConnectionInitializer connectionInitializer) { this.clusterConnectionMode = clusterConnectionMode; + this.authenticator = authenticator; this.isMonitoringConnection = isMonitoringConnection; this.serverId = notNull("serverId", serverId); this.connectionGenerationSupplier = notNull("connectionGeneration", connectionGenerationSupplier); @@ -271,6 +292,7 @@ private void initAfterHandshakeFinish(final InternalConnectionInitializationDesc description = initializationDescription.getConnectionDescription(); initialServerDescription = initializationDescription.getServerDescription(); opened.set(true); + authenticated.set(true); sendCompressor = findSendCompressor(description); } @@ -336,8 +358,35 @@ public boolean isClosed() { @Override public T sendAndReceive(final CommandMessage message, final Decoder decoder, final SessionContext sessionContext, final RequestContext requestContext, final OperationContext operationContext) { - CommandEventSender commandEventSender; + Supplier sendAndReceiveInternal = () -> sendAndReceiveInternal( + message, decoder, sessionContext, requestContext, operationContext); + try { + return sendAndReceiveInternal.get(); + } catch (MongoCommandException e) { + if (triggersReauthentication(e) && Authenticator.shouldAuthenticate(authenticator, this.description)) { + authenticated.set(false); + authenticator.reauthenticate(this); + authenticated.set(true); + return sendAndReceiveInternal.get(); + } + throw e; + } + } + + public static boolean triggersReauthentication(@Nullable final Throwable t) { + if (t instanceof MongoCommandException) { + MongoCommandException e = (MongoCommandException) t; + return e.getErrorCode() == 391; + } + return false; + } + + @Nullable + private T sendAndReceiveInternal(final CommandMessage message, final Decoder decoder, + final SessionContext sessionContext, final RequestContext requestContext, + final OperationContext operationContext) { + CommandEventSender commandEventSender; try (ByteBufferBsonOutput bsonOutput = new ByteBufferBsonOutput(this)) { message.encode(bsonOutput, sessionContext); commandEventSender = createCommandEventSender(message, bsonOutput, requestContext, operationContext); @@ -449,7 +498,7 @@ private T receiveCommandMessageResponse(final Decoder decoder, commandEventSender.sendFailedEvent(e); } throw e; - } + } } @Override @@ -839,12 +888,14 @@ public void onResult(@Nullable final ByteBuf result, @Nullable final Throwable t private CommandEventSender createCommandEventSender(final CommandMessage message, final ByteBufferBsonOutput bsonOutput, final RequestContext requestContext, final OperationContext operationContext) { - if (!isMonitoringConnection && opened() && (commandListener != null || COMMAND_PROTOCOL_LOGGER.isRequired(DEBUG, getClusterId()))) { - return new LoggingCommandEventSender(SECURITY_SENSITIVE_COMMANDS, SECURITY_SENSITIVE_HELLO_COMMANDS, description, - commandListener, requestContext, operationContext, message, bsonOutput, COMMAND_PROTOCOL_LOGGER, loggerSettings); - } else { + boolean listensOrLogs = commandListener != null || COMMAND_PROTOCOL_LOGGER.isRequired(DEBUG, getClusterId()); + if (!recordEverything && (isMonitoringConnection || !opened() || !authenticated.get() || !listensOrLogs)) { return new NoOpCommandEventSender(); } + return new LoggingCommandEventSender( + SECURITY_SENSITIVE_COMMANDS, SECURITY_SENSITIVE_HELLO_COMMANDS, description, commandListener, + requestContext, operationContext, message, bsonOutput, + COMMAND_PROTOCOL_LOGGER, loggerSettings); } private ClusterId getClusterId() { diff --git a/driver-core/src/main/com/mongodb/internal/connection/InternalStreamConnectionFactory.java b/driver-core/src/main/com/mongodb/internal/connection/InternalStreamConnectionFactory.java index 6cf2453c187..8b5c840c501 100644 --- a/driver-core/src/main/com/mongodb/internal/connection/InternalStreamConnectionFactory.java +++ b/driver-core/src/main/com/mongodb/internal/connection/InternalStreamConnectionFactory.java @@ -16,6 +16,7 @@ package com.mongodb.internal.connection; +import com.mongodb.AuthenticationMechanism; import com.mongodb.LoggerSettings; import com.mongodb.MongoCompressor; import com.mongodb.MongoDriverInformation; @@ -28,7 +29,6 @@ import java.util.List; -import static com.mongodb.assertions.Assertions.assertNotNull; import static com.mongodb.assertions.Assertions.notNull; import static com.mongodb.internal.connection.ClientMetadataHelper.createClientMetadataDocument; @@ -74,18 +74,21 @@ class InternalStreamConnectionFactory implements InternalConnectionFactory { @Override public InternalConnection create(final ServerId serverId, final ConnectionGenerationSupplier connectionGenerationSupplier) { Authenticator authenticator = credential == null ? null : createAuthenticator(credential); - return new InternalStreamConnection(clusterConnectionMode, isMonitoringConnection, serverId, connectionGenerationSupplier, + InternalStreamConnectionInitializer connectionInitializer = new InternalStreamConnectionInitializer( + clusterConnectionMode, authenticator, clientMetadataDocument, compressorList, serverApi); + return new InternalStreamConnection( + clusterConnectionMode, authenticator, + isMonitoringConnection, serverId, connectionGenerationSupplier, streamFactory, compressorList, loggerSettings, commandListener, - new InternalStreamConnectionInitializer(clusterConnectionMode, authenticator, clientMetadataDocument, compressorList, - serverApi)); + connectionInitializer); } private Authenticator createAuthenticator(final MongoCredentialWithCache credential) { - if (credential.getAuthenticationMechanism() == null) { + AuthenticationMechanism authenticationMechanism = credential.getAuthenticationMechanism(); + if (authenticationMechanism == null) { return new DefaultAuthenticator(credential, clusterConnectionMode, serverApi); } - - switch (assertNotNull(credential.getAuthenticationMechanism())) { + switch (authenticationMechanism) { case GSSAPI: return new GSSAPIAuthenticator(credential, clusterConnectionMode, serverApi); case PLAIN: @@ -97,8 +100,10 @@ private Authenticator createAuthenticator(final MongoCredentialWithCache credent return new ScramShaAuthenticator(credential, clusterConnectionMode, serverApi); case MONGODB_AWS: return new AwsAuthenticator(credential, clusterConnectionMode, serverApi); + case MONGODB_OIDC: + return new OidcAuthenticator(credential, clusterConnectionMode, serverApi); default: - throw new IllegalArgumentException("Unsupported authentication mechanism " + credential.getAuthenticationMechanism()); + throw new IllegalArgumentException("Unsupported authentication mechanism " + authenticationMechanism); } } } diff --git a/driver-core/src/main/com/mongodb/internal/connection/InternalStreamConnectionInitializer.java b/driver-core/src/main/com/mongodb/internal/connection/InternalStreamConnectionInitializer.java index f3d77ff2b2d..d4858f3d973 100644 --- a/driver-core/src/main/com/mongodb/internal/connection/InternalStreamConnectionInitializer.java +++ b/driver-core/src/main/com/mongodb/internal/connection/InternalStreamConnectionInitializer.java @@ -25,7 +25,6 @@ import com.mongodb.connection.ConnectionDescription; import com.mongodb.connection.ConnectionId; import com.mongodb.connection.ServerDescription; -import com.mongodb.connection.ServerType; import com.mongodb.internal.async.SingleResultCallback; import com.mongodb.lang.Nullable; import org.bson.BsonArray; @@ -82,8 +81,10 @@ public InternalConnectionInitializationDescription finishHandshake(final Interna final InternalConnectionInitializationDescription description) { notNull("internalConnection", internalConnection); notNull("description", description); - - authenticate(internalConnection, description.getConnectionDescription()); + final ConnectionDescription connectionDescription = description.getConnectionDescription(); + if (Authenticator.shouldAuthenticate(authenticator, connectionDescription)) { + authenticator.authenticate(internalConnection, connectionDescription); + } return completeConnectionDescriptionInitialization(internalConnection, description); } @@ -106,11 +107,12 @@ public void startHandshakeAsync(final InternalConnection internalConnection, public void finishHandshakeAsync(final InternalConnection internalConnection, final InternalConnectionInitializationDescription description, final SingleResultCallback callback) { - if (authenticator == null || description.getConnectionDescription().getServerType() - == ServerType.REPLICA_SET_ARBITER) { + ConnectionDescription connectionDescription = description.getConnectionDescription(); + + if (!Authenticator.shouldAuthenticate(authenticator, connectionDescription)) { completeConnectionDescriptionInitializationAsync(internalConnection, description, callback); } else { - authenticator.authenticateAsync(internalConnection, description.getConnectionDescription(), + authenticator.authenticateAsync(internalConnection, connectionDescription, (result1, t1) -> { if (t1 != null) { callback.onResult(null, t1); @@ -201,12 +203,6 @@ private InternalConnectionInitializationDescription completeConnectionDescriptio description); } - private void authenticate(final InternalConnection internalConnection, final ConnectionDescription connectionDescription) { - if (authenticator != null && connectionDescription.getServerType() != ServerType.REPLICA_SET_ARBITER) { - authenticator.authenticate(internalConnection, connectionDescription); - } - } - private void setSpeculativeAuthenticateResponse(final BsonDocument helloResult) { if (authenticator instanceof SpeculativeAuthenticator) { ((SpeculativeAuthenticator) authenticator).setSpeculativeAuthenticateResponse( diff --git a/driver-core/src/main/com/mongodb/internal/connection/MongoCredentialWithCache.java b/driver-core/src/main/com/mongodb/internal/connection/MongoCredentialWithCache.java index 43b9ad3eec5..3f3369059c3 100644 --- a/driver-core/src/main/com/mongodb/internal/connection/MongoCredentialWithCache.java +++ b/driver-core/src/main/com/mongodb/internal/connection/MongoCredentialWithCache.java @@ -22,8 +22,11 @@ import java.util.concurrent.locks.Lock; import java.util.concurrent.locks.ReentrantLock; +import java.util.concurrent.locks.StampedLock; import static com.mongodb.internal.Locks.withInterruptibleLock; +import static com.mongodb.internal.Locks.withLock; +import static com.mongodb.internal.connection.OidcAuthenticator.OidcCacheEntry; /** *

This class is not part of the public API and may be removed or changed at any time

@@ -33,12 +36,12 @@ public class MongoCredentialWithCache { private final Cache cache; public MongoCredentialWithCache(final MongoCredential credential) { - this(credential, null); + this(credential, new Cache()); } - private MongoCredentialWithCache(final MongoCredential credential, @Nullable final Cache cache) { + private MongoCredentialWithCache(final MongoCredential credential, final Cache cache) { this.credential = credential; - this.cache = cache != null ? cache : new Cache(); + this.cache = cache; } public MongoCredentialWithCache withMechanism(final AuthenticationMechanism mechanism) { @@ -63,15 +66,34 @@ public void putInCache(final Object key, final Object value) { cache.set(key, value); } + OidcCacheEntry getOidcCacheEntry() { + return cache.oidcCacheEntry; + } + + void setOidcCacheEntry(final OidcCacheEntry oidcCacheEntry) { + this.cache.oidcCacheEntry = oidcCacheEntry; + } + + StampedLock getOidcLock() { + return cache.oidcLock; + } + public Lock getLock() { return cache.lock; } + /** + * Stores any state associated with the credential. + */ static class Cache { private final ReentrantLock lock = new ReentrantLock(); private Object cacheKey; private Object cacheValue; + + private final StampedLock oidcLock = new StampedLock(); + private volatile OidcCacheEntry oidcCacheEntry = new OidcCacheEntry(); + Object get(final Object key) { return withInterruptibleLock(lock, () -> { if (cacheKey != null && cacheKey.equals(key)) { diff --git a/driver-core/src/main/com/mongodb/internal/connection/OidcAuthenticator.java b/driver-core/src/main/com/mongodb/internal/connection/OidcAuthenticator.java new file mode 100644 index 00000000000..f3c931a433f --- /dev/null +++ b/driver-core/src/main/com/mongodb/internal/connection/OidcAuthenticator.java @@ -0,0 +1,678 @@ +/* + * Copyright 2008-present MongoDB, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.mongodb.internal.connection; + +import com.mongodb.AuthenticationMechanism; +import com.mongodb.MongoClientException; +import com.mongodb.MongoCommandException; +import com.mongodb.MongoConfigurationException; +import com.mongodb.MongoCredential; +import com.mongodb.MongoCredential.IdpInfo; +import com.mongodb.MongoCredential.IdpResponse; +import com.mongodb.MongoException; +import com.mongodb.MongoSecurityException; +import com.mongodb.ServerAddress; +import com.mongodb.ServerApi; +import com.mongodb.connection.ClusterConnectionMode; +import com.mongodb.connection.ConnectionDescription; +import com.mongodb.internal.Locks; +import com.mongodb.internal.Timeout; +import com.mongodb.internal.VisibleForTesting; +import com.mongodb.lang.Nullable; +import org.bson.BsonDocument; +import org.bson.BsonString; +import org.bson.RawBsonDocument; +import org.jetbrains.annotations.NotNull; + +import javax.security.sasl.SaslClient; +import java.io.IOException; +import java.nio.charset.StandardCharsets; +import java.nio.file.Files; +import java.nio.file.Paths; +import java.time.Duration; +import java.util.Arrays; +import java.util.Collections; +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.concurrent.TimeUnit; +import java.util.function.Function; +import java.util.stream.Collectors; + +import static com.mongodb.AuthenticationMechanism.MONGODB_OIDC; +import static com.mongodb.MongoCredential.ALLOWED_HOSTS_KEY; +import static com.mongodb.MongoCredential.DEFAULT_ALLOWED_HOSTS; +import static com.mongodb.MongoCredential.OidcRefreshCallback; +import static com.mongodb.MongoCredential.OidcRefreshContext; +import static com.mongodb.MongoCredential.OidcRequestCallback; +import static com.mongodb.MongoCredential.OidcRequestContext; +import static com.mongodb.MongoCredential.PROVIDER_NAME_KEY; +import static com.mongodb.MongoCredential.REFRESH_TOKEN_CALLBACK_KEY; +import static com.mongodb.MongoCredential.REQUEST_TOKEN_CALLBACK_KEY; +import static com.mongodb.assertions.Assertions.assertFalse; +import static com.mongodb.assertions.Assertions.assertNotNull; +import static com.mongodb.assertions.Assertions.assertTrue; +import static com.mongodb.internal.connection.OidcAuthenticator.OidcValidator.validateBeforeUse; +import static java.lang.String.format; + +/** + *

This class is not part of the public API and may be removed or changed at any time

+ */ +public final class OidcAuthenticator extends SaslAuthenticator { + + private static final List SUPPORTED_PROVIDERS = Arrays.asList("aws"); + + private static final Duration CALLBACK_TIMEOUT = Duration.ofMinutes(5); + + private static final String AWS_WEB_IDENTITY_TOKEN_FILE = "AWS_WEB_IDENTITY_TOKEN_FILE"; + + @Nullable + private ServerAddress serverAddress; + + @Nullable + private String connectionLastAccessToken; + + private FallbackState fallbackState = FallbackState.INITIAL; + + @Nullable + private BsonDocument speculativeAuthenticateResponse; + + @Nullable + private Function evaluateChallengeFunction; + + public OidcAuthenticator(final MongoCredentialWithCache credential, + final ClusterConnectionMode clusterConnectionMode, @Nullable final ServerApi serverApi) { + super(credential, clusterConnectionMode, serverApi); + validateBeforeUse(credential.getCredential()); + + if (getMongoCredential().getAuthenticationMechanism() != MONGODB_OIDC) { + throw new MongoException("Incorrect mechanism: " + getMongoCredential().getMechanism()); + } + } + + @Override + public String getMechanismName() { + return MONGODB_OIDC.getMechanismName(); + } + + @Override + protected SaslClient createSaslClient(final ServerAddress serverAddress) { + this.serverAddress = serverAddress; + MongoCredentialWithCache mongoCredentialWithCache = getMongoCredentialWithCache(); + return new OidcSaslClient(mongoCredentialWithCache); + } + + @Override + @Nullable + public BsonDocument createSpeculativeAuthenticateCommand(final InternalConnection connection) { + try { + if (isAutomaticAuthentication()) { + return wrapInSpeculative(prepareAwsTokenFromFileAsJwt()); + } + String cachedAccessToken = getValidCachedAccessToken(); + MongoCredentialWithCache mongoCredentialWithCache = getMongoCredentialWithCache(); + if (cachedAccessToken != null) { + return wrapInSpeculative(prepareTokenAsJwt(cachedAccessToken)); + } else if (mongoCredentialWithCache.getOidcCacheEntry().getIdpInfo() == null) { + String userName = mongoCredentialWithCache.getCredential().getUserName(); + return wrapInSpeculative(prepareUsername(userName)); + } else { + // otherwise, skip speculative auth + return null; + } + } catch (Exception e) { + throw wrapException(e); + } + } + + @NotNull + private BsonDocument wrapInSpeculative(final byte[] outToken) { + BsonDocument startDocument = createSaslStartCommandDocument(outToken) + .append("db", new BsonString(getMongoCredential().getSource())); + appendSaslStartOptions(startDocument); + return startDocument; + } + + @Override + @Nullable + public BsonDocument getSpeculativeAuthenticateResponse() { + BsonDocument response = speculativeAuthenticateResponse; + // response should only be read once + this.speculativeAuthenticateResponse = null; + if (response == null) { + this.connectionLastAccessToken = null; + } + return response; + } + + @Override + public void setSpeculativeAuthenticateResponse(@Nullable final BsonDocument response) { + speculativeAuthenticateResponse = response; + } + + @Nullable + private OidcRefreshCallback getRefreshCallback() { + return getMongoCredentialWithCache() + .getCredential() + .getMechanismProperty(REFRESH_TOKEN_CALLBACK_KEY, null); + } + + @Nullable + private OidcRequestCallback getRequestCallback() { + return getMongoCredentialWithCache() + .getCredential() + .getMechanismProperty(REQUEST_TOKEN_CALLBACK_KEY, null); + } + + @Override + public void reauthenticate(final InternalConnection connection) { + // method must only be called after original handshake: + assertTrue(connection.opened()); + authLock(connection, connection.getDescription()); + } + + @Override + public void authenticate(final InternalConnection connection, final ConnectionDescription connectionDescription) { + // method must only be called during original handshake: + assertFalse(connection.opened()); + // this method "wraps" the default authentication method in custom OIDC retry logic + String accessToken = getValidCachedAccessToken(); + if (accessToken != null) { + try { + authenticateUsing(connection, connectionDescription, (challenge) -> prepareTokenAsJwt(accessToken)); + } catch (MongoSecurityException e) { + if (triggersRetry(e)) { + authLock(connection, connectionDescription); + } else { + throw e; + } + } + } else { + authLock(connection, connectionDescription); + } + } + + private static boolean triggersRetry(@Nullable final Throwable t) { + if (t instanceof MongoSecurityException) { + MongoSecurityException e = (MongoSecurityException) t; + Throwable cause = e.getCause(); + if (cause instanceof MongoCommandException) { + MongoCommandException commandCause = (MongoCommandException) cause; + return commandCause.getErrorCode() == 18; + } + } + return false; + } + + private void authenticateUsing( + final InternalConnection connection, + final ConnectionDescription connectionDescription, + final Function evaluateChallengeFunction) { + this.evaluateChallengeFunction = evaluateChallengeFunction; + super.authenticate(connection, connectionDescription); + } + + private void authLock(final InternalConnection connection, final ConnectionDescription connectionDescription) { + fallbackState = FallbackState.INITIAL; + Locks.withLock(getMongoCredentialWithCache().getOidcLock(), () -> { + while (true) { + try { + authenticateUsing(connection, connectionDescription, (challenge) -> evaluate(challenge)); + break; + } catch (MongoSecurityException e) { + if (!(triggersRetry(e) && shouldRetryHandler())) { + throw e; + } + } + } + return null; + }); + } + + private byte[] evaluate(final byte[] challenge) { + if (isAutomaticAuthentication()) { + return prepareAwsTokenFromFileAsJwt(); + } + + OidcRequestCallback requestCallback = assertNotNull(getRequestCallback()); + MongoCredentialWithCache mongoCredentialWithCache = getMongoCredentialWithCache(); + OidcCacheEntry cacheEntry = mongoCredentialWithCache.getOidcCacheEntry(); + String cachedAccessToken = getValidCachedAccessToken(); + String invalidConnectionAccessToken = connectionLastAccessToken; + String cachedRefreshToken = cacheEntry.getRefreshToken(); + IdpInfo cachedIdpInfo = cacheEntry.getIdpInfo(); + + if (cachedAccessToken != null) { + boolean cachedTokenIsInvalid = cachedAccessToken.equals(invalidConnectionAccessToken); + if (cachedTokenIsInvalid) { + mongoCredentialWithCache.setOidcCacheEntry(cacheEntry.clearAccessToken()); + cachedAccessToken = null; + } + } + OidcRefreshCallback refreshCallback = getRefreshCallback(); + if (cachedAccessToken != null) { + fallbackState = FallbackState.PHASE_1_CACHED_TOKEN; + return prepareTokenAsJwt(cachedAccessToken); + } else if (refreshCallback != null && cachedRefreshToken != null) { + assertNotNull(cachedIdpInfo); + // Invoke Refresh Callback using cached Refresh Token + validateAllowedHosts(getMongoCredential()); + fallbackState = FallbackState.PHASE_2_REFRESH_CALLBACK_TOKEN; + IdpResponse result = refreshCallback.onRefresh(new OidcRefreshContextImpl( + cachedIdpInfo, cachedRefreshToken, CALLBACK_TIMEOUT)); + return populateCacheWithCallbackResultAndPrepareJwt(cachedIdpInfo, result); + } else { + // cache is empty + + /* + A check for present idp info short-circuits phase-3a. + + If a challenge is present, it can only be a response to a + "principal-request", so the challenge must be the resulting + idp info. Such a request is made during speculative auth, + though the source is unimportant, as long as we detect and + use it here. + + Checking that the fallback state is not phase-3a ensures that + this does not loop infinitely in the case of a bug. + */ + boolean idpInfoNotPresent = challenge.length == 0; + if (fallbackState != FallbackState.PHASE_3A_PRINCIPAL && idpInfoNotPresent) { + fallbackState = FallbackState.PHASE_3A_PRINCIPAL; + return prepareUsername(mongoCredentialWithCache.getCredential().getUserName()); + } else { + IdpInfo idpInfo = toIdpInfo(challenge); + validateAllowedHosts(getMongoCredential()); + IdpResponse result = requestCallback.onRequest(new OidcRequestContextImpl(idpInfo, CALLBACK_TIMEOUT)); + fallbackState = FallbackState.PHASE_3B_REQUEST_CALLBACK_TOKEN; + return populateCacheWithCallbackResultAndPrepareJwt(idpInfo, result); + } + } + } + + private boolean isAutomaticAuthentication() { + return getRequestCallback() == null; + } + + private boolean clientIsComplete() { + return fallbackState != FallbackState.PHASE_3A_PRINCIPAL; + } + + private boolean shouldRetryHandler() { + MongoCredentialWithCache mongoCredentialWithCache = getMongoCredentialWithCache(); + OidcCacheEntry cacheEntry = mongoCredentialWithCache.getOidcCacheEntry(); + if (fallbackState == FallbackState.PHASE_1_CACHED_TOKEN) { + // a cached access token failed + mongoCredentialWithCache.setOidcCacheEntry(cacheEntry + .clearAccessToken()); + } else if (fallbackState == FallbackState.PHASE_2_REFRESH_CALLBACK_TOKEN) { + // a refresh token failed + mongoCredentialWithCache.setOidcCacheEntry(cacheEntry + .clearAccessToken() + .clearRefreshToken()); + } else { + // a clean-restart failed + mongoCredentialWithCache.setOidcCacheEntry(cacheEntry + .clearAccessToken() + .clearRefreshToken()); + return false; + } + return true; + } + + @Nullable + private String getValidCachedAccessToken() { + return getMongoCredentialWithCache() + .getOidcCacheEntry() + .getValidCachedAccessToken(); + } + + static final class OidcCacheEntry { + @Nullable + private final String accessToken; + @Nullable + private final Timeout accessTokenExpiry; + @Nullable + private final String refreshToken; + @Nullable + private final IdpInfo idpInfo; + + @Override + public String toString() { + return "OidcCacheEntry{" + + "\n accessToken#hashCode='" + Objects.hashCode(accessToken) + '\'' + + ",\n accessTokenExpiry=" + accessTokenExpiry + + ",\n refreshToken='" + refreshToken + '\'' + + ",\n idpInfo=" + idpInfo + + '}'; + } + + OidcCacheEntry(final IdpInfo idpInfo, final IdpResponse idpResponse) { + Integer accessTokenExpiresInSeconds = idpResponse.getAccessTokenExpiresInSeconds(); + if (accessTokenExpiresInSeconds != null) { + this.accessToken = idpResponse.getAccessToken(); + long accessTokenExpiryReservedSeconds = TimeUnit.MINUTES.toSeconds(5); + this.accessTokenExpiry = Timeout.startNow( + Math.max(0, accessTokenExpiresInSeconds - accessTokenExpiryReservedSeconds), + TimeUnit.SECONDS); + } else { + this.accessToken = null; + this.accessTokenExpiry = null; + } + String refreshToken = idpResponse.getRefreshToken(); + if (refreshToken != null) { + this.refreshToken = refreshToken; + this.idpInfo = idpInfo; + } else { + this.refreshToken = null; + this.idpInfo = null; + } + } + + OidcCacheEntry() { + this(null, null, null, null); + } + + private OidcCacheEntry(@Nullable final String accessToken, @Nullable final Timeout accessTokenExpiry, + @Nullable final String refreshToken, @Nullable final IdpInfo idpInfo) { + this.accessToken = accessToken; + this.accessTokenExpiry = accessTokenExpiry; + this.refreshToken = refreshToken; + this.idpInfo = idpInfo; + } + + @Nullable + String getValidCachedAccessToken() { + if (accessToken == null || accessTokenExpiry == null || accessTokenExpiry.expired()) { + return null; + } + return accessToken; + } + + @Nullable + String getRefreshToken() { + return refreshToken; + } + + @Nullable + IdpInfo getIdpInfo() { + return idpInfo; + } + + OidcCacheEntry clearAccessToken() { + return new OidcCacheEntry( + null, + null, + this.refreshToken, + this.idpInfo); + } + + OidcCacheEntry clearRefreshToken() { + return new OidcCacheEntry( + this.accessToken, + this.accessTokenExpiry, + null, + null); + } + } + + private final class OidcSaslClient extends SaslClientImpl { + + private OidcSaslClient(final MongoCredentialWithCache mongoCredentialWithCache) { + super(mongoCredentialWithCache.getCredential()); + } + + @Override + public byte[] evaluateChallenge(final byte[] challenge) { + return assertNotNull(evaluateChallengeFunction).apply(challenge); + } + + @Override + public boolean isComplete() { + return clientIsComplete(); + } + + } + + private static String readAwsTokenFromFile() { + String path = System.getenv(AWS_WEB_IDENTITY_TOKEN_FILE); + if (path == null) { + throw new MongoClientException( + format("Environment variable must be specified: %s", AWS_WEB_IDENTITY_TOKEN_FILE)); + } + try { + return new String(Files.readAllBytes(Paths.get(path)), StandardCharsets.UTF_8); + } catch (IOException e) { + throw new MongoClientException(format( + "Could not read file specified by environment variable: %s at path: %s", + AWS_WEB_IDENTITY_TOKEN_FILE, path), e); + } + } + + private static byte[] prepareUsername(@Nullable final String username) { + BsonDocument document = new BsonDocument(); + if (username != null) { + document = document.append("n", new BsonString(username)); + } + return toBson(document); + } + + private byte[] populateCacheWithCallbackResultAndPrepareJwt( + final IdpInfo serverInfo, + @Nullable final IdpResponse idpResponse) { + if (idpResponse == null) { + throw new MongoConfigurationException("Result of callback must not be null"); + } + OidcCacheEntry newEntry = new OidcCacheEntry(serverInfo, idpResponse); + getMongoCredentialWithCache().setOidcCacheEntry(newEntry); + return prepareTokenAsJwt(idpResponse.getAccessToken()); + } + + private static IdpInfo toIdpInfo(final byte[] challenge) { + BsonDocument c = new RawBsonDocument(challenge); + String issuer = c.getString("issuer").getValue(); + String clientId = c.getString("clientId").getValue(); + return new IdpInfoImpl( + issuer, + clientId, + getStringArray(c, "requestScopes")); + } + + private void validateAllowedHosts(final MongoCredential credential) { + List allowedHosts = assertNotNull(credential.getMechanismProperty(ALLOWED_HOSTS_KEY, DEFAULT_ALLOWED_HOSTS)); + String host = assertNotNull(serverAddress).getHost(); + boolean permitted = allowedHosts.stream().anyMatch(allowedHost -> { + if (allowedHost.startsWith("*.")) { + String ending = allowedHost.substring(1); + return host.endsWith(ending); + } else if (allowedHost.contains("*")) { + throw new IllegalArgumentException( + "Allowed host " + allowedHost + " contains invalid wildcard"); + } else { + return host.equals(allowedHost); + } + }); + if (!permitted) { + throw new MongoSecurityException( + credential, "Host not permitted by " + ALLOWED_HOSTS_KEY + ": " + host); + } + } + + @Nullable + private static List getStringArray(final BsonDocument document, final String key) { + if (!document.isArray(key)) { + return null; + } + return document.getArray(key).stream() + // ignore non-string values from server, rather than error + .filter(v -> v.isString()) + .map(v -> v.asString().getValue()) + .collect(Collectors.toList()); + } + + private byte[] prepareTokenAsJwt(final String accessToken) { + connectionLastAccessToken = accessToken; + return toJwtDocument(accessToken); + } + + private static byte[] prepareAwsTokenFromFileAsJwt() { + String accessToken = readAwsTokenFromFile(); + return toJwtDocument(accessToken); + } + + private static byte[] toJwtDocument(final String accessToken) { + return toBson(new BsonDocument().append("jwt", new BsonString(accessToken))); + } + + /** + * Contains all validation logic for OIDC in one location + */ + public static final class OidcValidator { + private OidcValidator() { + } + + public static void validateOidcCredentialConstruction( + final String source, + final Map mechanismProperties) { + + if (!"$external".equals(source)) { + throw new IllegalArgumentException("source must be '$external'"); + } + + Object providerName = mechanismProperties.get(PROVIDER_NAME_KEY.toLowerCase()); + if (providerName != null) { + if (!(providerName instanceof String) || !SUPPORTED_PROVIDERS.contains(providerName)) { + throw new IllegalArgumentException(PROVIDER_NAME_KEY + " must be one of: " + SUPPORTED_PROVIDERS); + } + } + } + + public static void validateCreateOidcCredential(@Nullable final char[] password) { + if (password != null) { + throw new IllegalArgumentException("password must not be specified for " + + AuthenticationMechanism.MONGODB_OIDC); + } + } + + @VisibleForTesting(otherwise = VisibleForTesting.AccessModifier.PRIVATE) + public static void validateBeforeUse(final MongoCredential credential) { + String userName = credential.getUserName(); + Object providerName = credential.getMechanismProperty(PROVIDER_NAME_KEY, null); + Object requestCallback = credential.getMechanismProperty(REQUEST_TOKEN_CALLBACK_KEY, null); + Object refreshCallback = credential.getMechanismProperty(REFRESH_TOKEN_CALLBACK_KEY, null); + if (providerName == null) { + // callback + if (requestCallback == null) { + throw new IllegalArgumentException("Either " + PROVIDER_NAME_KEY + " or " + + REQUEST_TOKEN_CALLBACK_KEY + " must be specified"); + } + } else { + // automatic + if (userName != null) { + throw new IllegalArgumentException("user name must not be specified when " + PROVIDER_NAME_KEY + " is specified"); + } + if (requestCallback != null) { + throw new IllegalArgumentException(REQUEST_TOKEN_CALLBACK_KEY + " must not be specified when " + PROVIDER_NAME_KEY + " is specified"); + } + if (refreshCallback != null) { + throw new IllegalArgumentException(REFRESH_TOKEN_CALLBACK_KEY + " must not be specified when " + PROVIDER_NAME_KEY + " is specified"); + } + } + } + } + + + @VisibleForTesting(otherwise = VisibleForTesting.AccessModifier.PRIVATE) + static class OidcRequestContextImpl implements OidcRequestContext { + private final IdpInfo idpInfo; + private final Duration timeout; + + OidcRequestContextImpl(final IdpInfo idpInfo, final Duration timeout) { + this.idpInfo = assertNotNull(idpInfo); + this.timeout = assertNotNull(timeout); + } + + @Override + public IdpInfo getIdpInfo() { + return idpInfo; + } + + @Override + public Duration getTimeout() { + return timeout; + } + } + + @VisibleForTesting(otherwise = VisibleForTesting.AccessModifier.PRIVATE) + static final class OidcRefreshContextImpl extends OidcRequestContextImpl + implements OidcRefreshContext { + private final String refreshToken; + + OidcRefreshContextImpl(final IdpInfo idpInfo, final String refreshToken, + final Duration timeout) { + super(idpInfo, timeout); + this.refreshToken = assertNotNull(refreshToken); + } + + @Override + public String getRefreshToken() { + return refreshToken; + } + } + + @VisibleForTesting(otherwise = VisibleForTesting.AccessModifier.PRIVATE) + static final class IdpInfoImpl implements IdpInfo { + private final String issuer; + private final String clientId; + + private final List requestScopes; + + IdpInfoImpl(final String issuer, final String clientId, @Nullable final List requestScopes) { + this.issuer = assertNotNull(issuer); + this.clientId = assertNotNull(clientId); + this.requestScopes = requestScopes == null + ? Collections.emptyList() + : Collections.unmodifiableList(requestScopes); + } + + @Override + public String getIssuer() { + return issuer; + } + + @Override + public String getClientId() { + return clientId; + } + + @Override + public List getRequestScopes() { + return requestScopes; + } + } + + /** + * Represents what was sent in the last request to the MongoDB server. + */ + private enum FallbackState { + INITIAL, + PHASE_1_CACHED_TOKEN, + PHASE_2_REFRESH_CALLBACK_TOKEN, + PHASE_3A_PRINCIPAL, + PHASE_3B_REQUEST_CALLBACK_TOKEN + } +} diff --git a/driver-core/src/main/com/mongodb/internal/connection/SaslAuthenticator.java b/driver-core/src/main/com/mongodb/internal/connection/SaslAuthenticator.java index 2c2321fcbad..335dce38a57 100644 --- a/driver-core/src/main/com/mongodb/internal/connection/SaslAuthenticator.java +++ b/driver-core/src/main/com/mongodb/internal/connection/SaslAuthenticator.java @@ -16,6 +16,8 @@ package com.mongodb.internal.connection; +import com.mongodb.AuthenticationMechanism; +import com.mongodb.MongoCredential; import com.mongodb.MongoException; import com.mongodb.MongoInterruptedException; import com.mongodb.MongoSecurityException; @@ -30,9 +32,13 @@ import com.mongodb.lang.NonNull; import com.mongodb.lang.Nullable; import org.bson.BsonBinary; +import org.bson.BsonBinaryWriter; import org.bson.BsonDocument; import org.bson.BsonInt32; import org.bson.BsonString; +import org.bson.codecs.BsonDocumentCodec; +import org.bson.codecs.EncoderContext; +import org.bson.io.BasicOutputBuffer; import javax.security.auth.Subject; import javax.security.auth.login.LoginException; @@ -55,6 +61,7 @@ abstract class SaslAuthenticator extends Authenticator implements SpeculativeAut super(credential, clusterConnectionMode, serverApi); } + @Override public void authenticate(final InternalConnection connection, final ConnectionDescription connectionDescription) { doAsSubject(() -> { SaslClient saslClient = createSaslClient(connection.getDescription().getServerAddress()); @@ -121,9 +128,11 @@ private void throwIfSaslClientIsNull(@Nullable final SaslClient saslClient) { } private BsonDocument getNextSaslResponse(final SaslClient saslClient, final InternalConnection connection) { - BsonDocument response = getSpeculativeAuthenticateResponse(); - if (response != null) { - return response; + if (!connection.opened()) { + BsonDocument response = getSpeculativeAuthenticateResponse(); + if (response != null) { + return response; + } } try { @@ -136,9 +145,9 @@ private BsonDocument getNextSaslResponse(final SaslClient saslClient, final Inte private void getNextSaslResponseAsync(final SaslClient saslClient, final InternalConnection connection, final SingleResultCallback callback) { - BsonDocument response = getSpeculativeAuthenticateResponse(); SingleResultCallback errHandlingCallback = errorHandlingCallback(callback, LOGGER); try { + BsonDocument response = getSpeculativeAuthenticateResponse(); if (response == null) { byte[] serverResponse = (saslClient.hasInitialResponse() ? saslClient.evaluateChallenge(new byte[0]) : null); sendSaslStartAsync(serverResponse, connection, (result, t) -> { @@ -280,6 +289,15 @@ void doAsSubject(final java.security.PrivilegedAction action) { } } + static byte[] toBson(final BsonDocument document) { + byte[] bytes; + BasicOutputBuffer buffer = new BasicOutputBuffer(); + new BsonDocumentCodec().encode(new BsonBinaryWriter(buffer), document, EncoderContext.builder().build()); + bytes = new byte[buffer.size()]; + System.arraycopy(buffer.getInternalBuffer(), 0, bytes, 0, buffer.getSize()); + return bytes; + } + private final class Continuator implements SingleResultCallback { private final SaslClient saslClient; private final BsonDocument saslStartDocument; @@ -331,7 +349,51 @@ private void continueConversation(final BsonDocument result) { disposeOfSaslClient(saslClient); } } - } + protected abstract static class SaslClientImpl implements SaslClient { + private final MongoCredential credential; + + protected SaslClientImpl(final MongoCredential credential) { + this.credential = credential; + } + + @Override + public boolean hasInitialResponse() { + return true; + } + + @Override + public byte[] unwrap(final byte[] bytes, final int i, final int i1) { + throw new UnsupportedOperationException("Not implemented."); + } + + @Override + public byte[] wrap(final byte[] bytes, final int i, final int i1) { + throw new UnsupportedOperationException("Not implemented."); + } + + @Override + public Object getNegotiatedProperty(final String s) { + throw new UnsupportedOperationException("Not implemented."); + } + + @Override + public void dispose() { + // nothing to do + } + + @Override + public final String getMechanismName() { + AuthenticationMechanism authMechanism = getCredential().getAuthenticationMechanism(); + if (authMechanism == null) { + throw new IllegalArgumentException("Authentication mechanism cannot be null"); + } + return authMechanism.getMechanismName(); + } + + protected final MongoCredential getCredential() { + return credential; + } + } } diff --git a/driver-core/src/main/com/mongodb/internal/connection/ScramShaAuthenticator.java b/driver-core/src/main/com/mongodb/internal/connection/ScramShaAuthenticator.java index 5dec0d90c1e..02bc7912c93 100644 --- a/driver-core/src/main/com/mongodb/internal/connection/ScramShaAuthenticator.java +++ b/driver-core/src/main/com/mongodb/internal/connection/ScramShaAuthenticator.java @@ -92,7 +92,7 @@ protected SaslClient createSaslClient(final ServerAddress serverAddress) { if (speculativeSaslClient != null) { return speculativeSaslClient; } - return new ScramShaSaslClient(getMongoCredentialWithCache(), randomStringGenerator, authenticationHashGenerator); + return new ScramShaSaslClient(getMongoCredentialWithCache().getCredential(), randomStringGenerator, authenticationHashGenerator); } @Override @@ -122,9 +122,7 @@ public void setSpeculativeAuthenticateResponse(@Nullable final BsonDocument resp } } - class ScramShaSaslClient implements SaslClient { - - private final MongoCredentialWithCache credential; + class ScramShaSaslClient extends SaslClientImpl { private final RandomStringGenerator randomStringGenerator; private final AuthenticationHashGenerator authenticationHashGenerator; private final String hAlgorithm; @@ -136,9 +134,11 @@ class ScramShaSaslClient implements SaslClient { private byte[] serverSignature; private int step = -1; - ScramShaSaslClient(final MongoCredentialWithCache credential, final RandomStringGenerator randomStringGenerator, - final AuthenticationHashGenerator authenticationHashGenerator) { - this.credential = credential; + ScramShaSaslClient( + final MongoCredential credential, + final RandomStringGenerator randomStringGenerator, + final AuthenticationHashGenerator authenticationHashGenerator) { + super(credential); this.randomStringGenerator = randomStringGenerator; this.authenticationHashGenerator = authenticationHashGenerator; if (assertNotNull(credential.getAuthenticationMechanism()).equals(SCRAM_SHA_1)) { @@ -150,14 +150,6 @@ class ScramShaSaslClient implements SaslClient { } } - public String getMechanismName() { - return assertNotNull(credential.getAuthenticationMechanism()).getMechanismName(); - } - - public boolean hasInitialResponse() { - return true; - } - public byte[] evaluateChallenge(final byte[] challenge) throws SaslException { step++; if (step == 0) { @@ -167,7 +159,8 @@ public byte[] evaluateChallenge(final byte[] challenge) throws SaslException { } else if (step == 2) { return validateServerSignature(challenge); } else { - throw new SaslException(format("Too many steps involved in the %s negotiation.", getMechanismName())); + throw new SaslException(format("Too many steps involved in the %s negotiation.", + super.getMechanismName())); } } @@ -184,22 +177,6 @@ public boolean isComplete() { return step == 2; } - public byte[] unwrap(final byte[] incoming, final int offset, final int len) { - throw new UnsupportedOperationException("Not implemented yet!"); - } - - public byte[] wrap(final byte[] outgoing, final int offset, final int len) { - throw new UnsupportedOperationException("Not implemented yet!"); - } - - public Object getNegotiatedProperty(final String propName) { - throw new UnsupportedOperationException("Not implemented yet!"); - } - - public void dispose() { - // nothing to do - } - private byte[] computeClientFirstMessage() { clientNonce = randomStringGenerator.generate(RANDOM_LENGTH); String clientFirstMessage = "n=" + getUserName() + ",r=" + clientNonce; @@ -318,9 +295,8 @@ private HashMap parseServerResponse(final String response) { return map; } - private String getUserName() { - String userName = credential.getCredential().getUserName(); + String userName = getCredential().getUserName(); if (userName == null) { throw new IllegalArgumentException("Username can not be null"); } @@ -328,8 +304,8 @@ private String getUserName() { } private String getAuthenicationHash() { - String password = authenticationHashGenerator.generate(credential.getCredential()); - if (credential.getAuthenticationMechanism() == SCRAM_SHA_256) { + String password = authenticationHashGenerator.generate(getCredential()); + if (getCredential().getAuthenticationMechanism() == SCRAM_SHA_256) { password = SaslPrep.saslPrepStored(password); } return password; diff --git a/driver-core/src/test/functional/com/mongodb/client/TestHelper.java b/driver-core/src/test/functional/com/mongodb/client/TestHelper.java new file mode 100644 index 00000000000..237c03c7e19 --- /dev/null +++ b/driver-core/src/test/functional/com/mongodb/client/TestHelper.java @@ -0,0 +1,47 @@ +/* + * Copyright 2008-present MongoDB, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.mongodb.client; + +import com.mongodb.lang.Nullable; + +import java.lang.reflect.Field; +import java.util.Map; + +import static java.lang.System.getenv; + +public final class TestHelper { + + public static void setEnvironmentVariable(final String name, @Nullable final String value) { + try { + Map env = getenv(); + Field field = env.getClass().getDeclaredField("m"); + field.setAccessible(true); + @SuppressWarnings("unchecked") + Map result = (Map) field.get(env); + if (value == null) { + result.remove(name); + } else { + result.put(name, value); + } + } catch (IllegalAccessException | NoSuchFieldException e) { + throw new RuntimeException(e); + } + } + + private TestHelper() { + } +} diff --git a/driver-core/src/test/functional/com/mongodb/internal/connection/TestCommandListener.java b/driver-core/src/test/functional/com/mongodb/internal/connection/TestCommandListener.java index 0a2838c2d55..c8274f382fc 100644 --- a/driver-core/src/test/functional/com/mongodb/internal/connection/TestCommandListener.java +++ b/driver-core/src/test/functional/com/mongodb/internal/connection/TestCommandListener.java @@ -17,11 +17,13 @@ package com.mongodb.internal.connection; import com.mongodb.MongoTimeoutException; +import com.mongodb.client.TestListener; import com.mongodb.event.CommandEvent; import com.mongodb.event.CommandFailedEvent; import com.mongodb.event.CommandListener; import com.mongodb.event.CommandStartedEvent; import com.mongodb.event.CommandSucceededEvent; +import com.mongodb.lang.Nullable; import org.bson.BsonDocument; import org.bson.BsonDocumentWriter; import org.bson.BsonDouble; @@ -55,6 +57,8 @@ public class TestCommandListener implements CommandListener { private final List eventTypes; private final List ignoredCommandMonitoringEvents; private final List events = new ArrayList<>(); + @Nullable + private final TestListener listener; private final Lock lock = new ReentrantLock(); private final Condition commandCompletedCondition = lock.newCondition(); private final boolean observeSensitiveCommands; @@ -76,25 +80,44 @@ public Codec get(final Class clazz, final CodecRegistry registry) { }); } + /** + * When a test listener is set, this command listener will send string events to the + * test listener in the form {@code " "}, where the event + * type will be lowercase and will omit the terms "command" and "event". + * For example: {@code "saslContinue succeeded"}. + * + * @see InternalStreamConnection#setRecordEverything(boolean) + * @param listener the test listener + */ + public TestCommandListener(final TestListener listener) { + this(Arrays.asList("commandStartedEvent", "commandSucceededEvent", "commandFailedEvent"), emptyList(), true, listener); + } + public TestCommandListener() { this(Arrays.asList("commandStartedEvent", "commandSucceededEvent", "commandFailedEvent"), emptyList()); } public TestCommandListener(final List eventTypes, final List ignoredCommandMonitoringEvents) { - this(eventTypes, ignoredCommandMonitoringEvents, true); + this(eventTypes, ignoredCommandMonitoringEvents, true, null); } public TestCommandListener(final List eventTypes, final List ignoredCommandMonitoringEvents, - final boolean observeSensitiveCommands) { + final boolean observeSensitiveCommands, @Nullable final TestListener listener) { this.eventTypes = eventTypes; this.ignoredCommandMonitoringEvents = ignoredCommandMonitoringEvents; this.observeSensitiveCommands = observeSensitiveCommands; + this.listener = listener; } + + public void reset() { lock.lock(); try { events.clear(); + if (listener != null) { + listener.clear(); + } } finally { lock.unlock(); } @@ -109,6 +132,18 @@ public List getEvents() { } } + private void addEvent(final CommandEvent c) { + events.add(c); + String className = c.getClass().getSimpleName() + .replace("Command", "") + .replace("Event", "") + .toLowerCase(); + // example: "saslContinue succeeded" + if (listener != null) { + listener.add(c.getCommandName() + " " + className); + } + } + public CommandStartedEvent getCommandStartedEvent(final String commandName) { for (CommandEvent event : getCommandStartedEvents()) { if (event instanceof CommandStartedEvent) { @@ -226,7 +261,7 @@ else if (!observeSensitiveCommands) { } lock.lock(); try { - events.add(new CommandStartedEvent(event.getRequestContext(), event.getOperationId(), event.getRequestId(), + addEvent(new CommandStartedEvent(event.getRequestContext(), event.getOperationId(), event.getRequestId(), event.getConnectionDescription(), event.getDatabaseName(), event.getCommandName(), event.getCommand() == null ? null : getWritableClone(event.getCommand()))); } finally { @@ -249,7 +284,7 @@ else if (!observeSensitiveCommands) { } lock.lock(); try { - events.add(new CommandSucceededEvent(event.getRequestContext(), event.getOperationId(), event.getRequestId(), + addEvent(new CommandSucceededEvent(event.getRequestContext(), event.getOperationId(), event.getRequestId(), event.getConnectionDescription(), event.getDatabaseName(), event.getCommandName(), event.getResponse() == null ? null : event.getResponse().clone(), event.getElapsedTime(TimeUnit.NANOSECONDS))); @@ -274,7 +309,7 @@ else if (!observeSensitiveCommands) { } lock.lock(); try { - events.add(event); + addEvent(event); commandCompletedCondition.signal(); } finally { lock.unlock(); diff --git a/driver-core/src/test/resources/auth/connection-string.json b/driver-core/src/test/resources/auth/legacy/connection-string.json similarity index 74% rename from driver-core/src/test/resources/auth/connection-string.json rename to driver-core/src/test/resources/auth/legacy/connection-string.json index 2a37ae8df47..1d69685df10 100644 --- a/driver-core/src/test/resources/auth/connection-string.json +++ b/driver-core/src/test/resources/auth/legacy/connection-string.json @@ -444,6 +444,147 @@ "AWS_SESSION_TOKEN": "token!@#$%^&*()_+" } } + }, + { + "description": "should recognise the mechanism and request callback (MONGODB-OIDC)", + "uri": "mongodb://localhost/?authMechanism=MONGODB-OIDC", + "callback": ["oidcRequest"], + "valid": true, + "credential": { + "username": null, + "password": null, + "source": "$external", + "mechanism": "MONGODB-OIDC", + "mechanism_properties": { + "REQUEST_TOKEN_CALLBACK": true + } + } + }, + { + "description": "should recognise the mechanism when auth source is explicitly specified and with request callback (MONGODB-OIDC)", + "uri": "mongodb://localhost/?authMechanism=MONGODB-OIDC&authSource=$external", + "callback": ["oidcRequest"], + "valid": true, + "credential": { + "username": null, + "password": null, + "source": "$external", + "mechanism": "MONGODB-OIDC", + "mechanism_properties": { + "REQUEST_TOKEN_CALLBACK": true + } + } + }, + { + "description": "should recognise the mechanism with request and refresh callback (MONGODB-OIDC)", + "uri": "mongodb://localhost/?authMechanism=MONGODB-OIDC", + "callback": ["oidcRequest", "oidcRefresh"], + "valid": true, + "credential": { + "username": null, + "password": null, + "source": "$external", + "mechanism": "MONGODB-OIDC", + "mechanism_properties": { + "REQUEST_TOKEN_CALLBACK": true, + "REFRESH_TOKEN_CALLBACK": true + } + } + }, + { + "description": "should recognise the mechanism and username with request callback (MONGODB-OIDC)", + "uri": "mongodb://principalName@localhost/?authMechanism=MONGODB-OIDC", + "callback": ["oidcRequest"], + "valid": true, + "credential": { + "username": "principalName", + "password": null, + "source": "$external", + "mechanism": "MONGODB-OIDC", + "mechanism_properties": { + "REQUEST_TOKEN_CALLBACK": true + } + } + }, + { + "description": "should recognise the mechanism with aws device (MONGODB-OIDC)", + "uri": "mongodb://localhost/?authMechanism=MONGODB-OIDC&authMechanismProperties=PROVIDER_NAME:aws", + "valid": true, + "credential": { + "username": null, + "password": null, + "source": "$external", + "mechanism": "MONGODB-OIDC", + "mechanism_properties": { + "PROVIDER_NAME": "aws" + } + } + }, + { + "description": "should recognise the mechanism when auth source is explicitly specified and with aws device (MONGODB-OIDC)", + "uri": "mongodb://localhost/?authMechanism=MONGODB-OIDC&authSource=$external&authMechanismProperties=PROVIDER_NAME:aws", + "valid": true, + "credential": { + "username": null, + "password": null, + "source": "$external", + "mechanism": "MONGODB-OIDC", + "mechanism_properties": { + "PROVIDER_NAME": "aws" + } + } + }, + { + "description": "should throw an exception if username and password are specified (MONGODB-OIDC)", + "uri": "mongodb://user:pass@localhost/?authMechanism=MONGODB-OIDC", + "callback": ["oidcRequest"], + "valid": false, + "credential": null + }, + { + "description": "should throw an exception if username and deviceName are specified (MONGODB-OIDC)", + "uri": "mongodb://principalName@localhost/?authMechanism=MONGODB-OIDC&PROVIDER_NAME:gcp", + "valid": false, + "credential": null + }, + { + "description": "should throw an exception if specified deviceName is not supported (MONGODB-OIDC)", + "uri": "mongodb://localhost/?authMechanism=MONGODB-OIDC&authMechanismProperties=PROVIDER_NAME:unexisted", + "valid": false, + "credential": null + }, + { + "description": "should throw an exception if neither deviceName nor callbacks specified (MONGODB-OIDC)", + "uri": "mongodb://localhost/?authMechanism=MONGODB-OIDC", + "valid": false, + "credential": null + }, + { + "description": "should throw an exception when only refresh callback is specified (MONGODB-OIDC)", + "uri": "mongodb://localhost/?authMechanism=MONGODB-OIDC", + "callback": ["oidcRefresh"], + "valid": false, + "credential": null + }, + { + "description": "should throw an exception if provider name and request callback are specified", + "uri": "mongodb://localhost/?authMechanism=MONGODB-OIDC&authMechanismProperties=PROVIDER_NAME:aws", + "callback": ["oidcRequest"], + "valid": false, + "credential": null + }, + { + "description": "should throw an exception if provider name and refresh callback are specified", + "uri": "mongodb://localhost/?authMechanism=MONGODB-OIDC&authMechanismProperties=PROVIDER_NAME:aws", + "callback": ["oidcRefresh"], + "valid": false, + "credential": null + }, + { + "description": "should throw an exception when unsupported auth property is specified (MONGODB-OIDC)", + "uri": "mongodb://localhost/?authMechanism=MONGODB-OIDC&authMechanismProperties=UnsupportedProperty:unexisted", + "valid": false, + "credential": null } ] -} +} \ No newline at end of file diff --git a/driver-core/src/test/resources/unified-test-format/auth/reauthenticate_with_retry.json b/driver-core/src/test/resources/unified-test-format/auth/reauthenticate_with_retry.json new file mode 100644 index 00000000000..c99ebc6ece2 --- /dev/null +++ b/driver-core/src/test/resources/unified-test-format/auth/reauthenticate_with_retry.json @@ -0,0 +1,191 @@ +{ + "description": "reauthenticate_with_retry", + "schemaVersion": "1.12", + "runOnRequirements": [ + { + "minServerVersion": "6.3", + "auth": true + } + ], + "createEntities": [ + { + "client": { + "id": "client0", + "uriOptions": { + "retryReads": true, + "retryWrites": true + }, + "observeEvents": [ + "commandStartedEvent", + "commandSucceededEvent", + "commandFailedEvent" + ] + } + }, + { + "database": { + "id": "database0", + "client": "client0", + "databaseName": "db" + } + }, + { + "collection": { + "id": "collection0", + "database": "database0", + "collectionName": "collName" + } + } + ], + "initialData": [ + { + "collectionName": "collName", + "databaseName": "db", + "documents": [] + } + ], + "tests": [ + { + "description": "Read command should reauthenticate when receive ReauthenticationRequired error code and retryReads=true", + "operations": [ + { + "name": "failPoint", + "object": "testRunner", + "arguments": { + "client": "client0", + "failPoint": { + "configureFailPoint": "failCommand", + "mode": { + "times": 1 + }, + "data": { + "failCommands": [ + "find" + ], + "errorCode": 391 + } + } + } + }, + { + "name": "find", + "arguments": { + "filter": {} + }, + "object": "collection0", + "expectResult": [] + } + ], + "expectEvents": [ + { + "client": "client0", + "events": [ + { + "commandStartedEvent": { + "command": { + "find": "collName", + "filter": {} + } + } + }, + { + "commandFailedEvent": { + "commandName": "find" + } + }, + { + "commandStartedEvent": { + "command": { + "find": "collName", + "filter": {} + } + } + }, + { + "commandSucceededEvent": { + "commandName": "find" + } + } + ] + } + ] + }, + { + "description": "Write command should reauthenticate when receive ReauthenticationRequired error code and retryWrites=true", + "operations": [ + { + "name": "failPoint", + "object": "testRunner", + "arguments": { + "client": "client0", + "failPoint": { + "configureFailPoint": "failCommand", + "mode": { + "times": 1 + }, + "data": { + "failCommands": [ + "insert" + ], + "errorCode": 391 + } + } + } + }, + { + "name": "insertOne", + "object": "collection0", + "arguments": { + "document": { + "_id": 1, + "x": 1 + } + } + } + ], + "expectEvents": [ + { + "client": "client0", + "events": [ + { + "commandStartedEvent": { + "command": { + "insert": "collName", + "documents": [ + { + "_id": 1, + "x": 1 + } + ] + } + } + }, + { + "commandFailedEvent": { + "commandName": "insert" + } + }, + { + "commandStartedEvent": { + "command": { + "insert": "collName", + "documents": [ + { + "_id": 1, + "x": 1 + } + ] + } + } + }, + { + "commandSucceededEvent": { + "commandName": "insert" + } + } + ] + } + ] + } + ] +} \ No newline at end of file diff --git a/driver-core/src/test/resources/unified-test-format/auth/reauthenticate_without_retry.json b/driver-core/src/test/resources/unified-test-format/auth/reauthenticate_without_retry.json new file mode 100644 index 00000000000..799057bf74f --- /dev/null +++ b/driver-core/src/test/resources/unified-test-format/auth/reauthenticate_without_retry.json @@ -0,0 +1,191 @@ +{ + "description": "reauthenticate_without_retry", + "schemaVersion": "1.12", + "runOnRequirements": [ + { + "minServerVersion": "6.3", + "auth": true + } + ], + "createEntities": [ + { + "client": { + "id": "client0", + "uriOptions": { + "retryReads": false, + "retryWrites": false + }, + "observeEvents": [ + "commandStartedEvent", + "commandSucceededEvent", + "commandFailedEvent" + ] + } + }, + { + "database": { + "id": "database0", + "client": "client0", + "databaseName": "db" + } + }, + { + "collection": { + "id": "collection0", + "database": "database0", + "collectionName": "collName" + } + } + ], + "initialData": [ + { + "collectionName": "collName", + "databaseName": "db", + "documents": [] + } + ], + "tests": [ + { + "description": "Read command should reauthenticate when receive ReauthenticationRequired error code and retryReads=false", + "operations": [ + { + "name": "failPoint", + "object": "testRunner", + "arguments": { + "client": "client0", + "failPoint": { + "configureFailPoint": "failCommand", + "mode": { + "times": 1 + }, + "data": { + "failCommands": [ + "find" + ], + "errorCode": 391 + } + } + } + }, + { + "name": "find", + "arguments": { + "filter": {} + }, + "object": "collection0", + "expectResult": [] + } + ], + "expectEvents": [ + { + "client": "client0", + "events": [ + { + "commandStartedEvent": { + "command": { + "find": "collName", + "filter": {} + } + } + }, + { + "commandFailedEvent": { + "commandName": "find" + } + }, + { + "commandStartedEvent": { + "command": { + "find": "collName", + "filter": {} + } + } + }, + { + "commandSucceededEvent": { + "commandName": "find" + } + } + ] + } + ] + }, + { + "description": "Write command should reauthenticate when receive ReauthenticationRequired error code and retryWrites=false", + "operations": [ + { + "name": "failPoint", + "object": "testRunner", + "arguments": { + "client": "client0", + "failPoint": { + "configureFailPoint": "failCommand", + "mode": { + "times": 1 + }, + "data": { + "failCommands": [ + "insert" + ], + "errorCode": 391 + } + } + } + }, + { + "name": "insertOne", + "object": "collection0", + "arguments": { + "document": { + "_id": 1, + "x": 1 + } + } + } + ], + "expectEvents": [ + { + "client": "client0", + "events": [ + { + "commandStartedEvent": { + "command": { + "insert": "collName", + "documents": [ + { + "_id": 1, + "x": 1 + } + ] + } + } + }, + { + "commandFailedEvent": { + "commandName": "insert" + } + }, + { + "commandStartedEvent": { + "command": { + "insert": "collName", + "documents": [ + { + "_id": 1, + "x": 1 + } + ] + } + } + }, + { + "commandSucceededEvent": { + "commandName": "insert" + } + } + ] + } + ] + } + ] +} \ No newline at end of file diff --git a/driver-core/src/test/unit/com/mongodb/AuthConnectionStringTest.java b/driver-core/src/test/unit/com/mongodb/AuthConnectionStringTest.java index dfb81ba8de4..7f4acab857d 100644 --- a/driver-core/src/test/unit/com/mongodb/AuthConnectionStringTest.java +++ b/driver-core/src/test/unit/com/mongodb/AuthConnectionStringTest.java @@ -16,9 +16,13 @@ package com.mongodb; +import com.mongodb.internal.connection.OidcAuthenticator; +import com.mongodb.lang.Nullable; import junit.framework.TestCase; +import org.bson.BsonArray; import org.bson.BsonDocument; import org.bson.BsonNull; +import org.bson.BsonString; import org.bson.BsonValue; import org.junit.Test; import org.junit.runner.RunWith; @@ -32,7 +36,11 @@ import java.util.Collection; import java.util.List; -// See https://github.com/mongodb/specifications/tree/master/source/auth/tests +import static com.mongodb.AuthenticationMechanism.MONGODB_OIDC; +import static com.mongodb.MongoCredential.REFRESH_TOKEN_CALLBACK_KEY; +import static com.mongodb.MongoCredential.REQUEST_TOKEN_CALLBACK_KEY; + +// See https://github.com/mongodb/specifications/tree/master/source/auth/legacy/tests @RunWith(Parameterized.class) public class AuthConnectionStringTest extends TestCase { private final String input; @@ -56,7 +64,7 @@ public void shouldPassAllOutcomes() { @Parameterized.Parameters(name = "{1}") public static Collection data() throws URISyntaxException, IOException { List data = new ArrayList<>(); - for (File file : JsonPoweredTestHelper.getTestFiles("/auth")) { + for (File file : JsonPoweredTestHelper.getTestFiles("/auth/legacy")) { BsonDocument testDocument = JsonPoweredTestHelper.getTestDocument(file); for (BsonValue test : testDocument.getArray("tests")) { data.add(new Object[]{file.getName(), test.asDocument().getString("description").getValue(), @@ -69,7 +77,7 @@ public static Collection data() throws URISyntaxException, IOException private void testInvalidUris() { Throwable expectedError = null; try { - new ConnectionString(input).getCredential(); + getMongoCredential(); } catch (Throwable t) { expectedError = t; } @@ -78,7 +86,7 @@ private void testInvalidUris() { } private void testValidUris() { - MongoCredential credential = new ConnectionString(input).getCredential(); + MongoCredential credential = getMongoCredential(); if (credential != null) { assertString("credential.source", credential.getSource()); @@ -99,6 +107,36 @@ private void testValidUris() { } } + @Nullable + private MongoCredential getMongoCredential() { + ConnectionString connectionString; + connectionString = new ConnectionString(input); + MongoCredential credential = connectionString.getCredential(); + if (credential != null) { + BsonArray callbacks = (BsonArray) getExpectedValue("callback"); + if (callbacks != null) { + for (BsonValue v : callbacks) { + String string = ((BsonString) v).getValue(); + if ("oidcRequest".equals(string)) { + credential = credential.withMechanismProperty( + REQUEST_TOKEN_CALLBACK_KEY, + (MongoCredential.OidcRequestCallback) (context) -> null); + } else if ("oidcRefresh".equals(string)) { + credential = credential.withMechanismProperty( + REFRESH_TOKEN_CALLBACK_KEY, + (MongoCredential.OidcRefreshCallback) (context) -> null); + } else { + fail("Unsupported callback: " + string); + } + } + } + if (MONGODB_OIDC.getMechanismName().equals(credential.getMechanism())) { + OidcAuthenticator.OidcValidator.validateBeforeUse(credential); + } + } + return credential; + } + private void assertString(final String key, final String actual) { BsonValue expected = getExpectedValue(key); @@ -142,6 +180,14 @@ private void assertMechanismProperties(final MongoCredential credential) { } } else if ((document.get(key).isBoolean())) { boolean expectedValue = document.getBoolean(key).getValue(); + if (REQUEST_TOKEN_CALLBACK_KEY.equals(key)) { + assertTrue(actualMechanismProperty instanceof MongoCredential.OidcRequestCallback); + return; + } + if (REFRESH_TOKEN_CALLBACK_KEY.equals(key)) { + assertTrue(actualMechanismProperty instanceof MongoCredential.OidcRefreshCallback); + return; + } assertNotNull(actualMechanismProperty); assertEquals(expectedValue, actualMechanismProperty); } else { diff --git a/driver-sync/src/test/functional/com/mongodb/client/unified/Entities.java b/driver-sync/src/test/functional/com/mongodb/client/unified/Entities.java index 4845ac460a1..773addf8767 100644 --- a/driver-sync/src/test/functional/com/mongodb/client/unified/Entities.java +++ b/driver-sync/src/test/functional/com/mongodb/client/unified/Entities.java @@ -391,8 +391,10 @@ private void initClient(final BsonDocument entity, final String id, .getArray("ignoreCommandMonitoringEvents", new BsonArray()).stream() .map(type -> type.asString().getValue()).collect(Collectors.toList()); ignoreCommandMonitoringEvents.add("configureFailPoint"); - TestCommandListener testCommandListener = new TestCommandListener(observeEvents, - ignoreCommandMonitoringEvents, entity.getBoolean("observeSensitiveCommands", BsonBoolean.FALSE).getValue()); + TestCommandListener testCommandListener = new TestCommandListener( + observeEvents, + ignoreCommandMonitoringEvents, entity.getBoolean("observeSensitiveCommands", BsonBoolean.FALSE).getValue(), + null); clientSettingsBuilder.addCommandListener(testCommandListener); putEntity(id + "-command-listener", testCommandListener, clientCommandListeners); diff --git a/driver-sync/src/test/functional/com/mongodb/client/unified/UnifiedAuthTest.java b/driver-sync/src/test/functional/com/mongodb/client/unified/UnifiedAuthTest.java new file mode 100644 index 00000000000..f94977f2546 --- /dev/null +++ b/driver-sync/src/test/functional/com/mongodb/client/unified/UnifiedAuthTest.java @@ -0,0 +1,39 @@ +/* + * Copyright 2008-present MongoDB, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.mongodb.client.unified; + +import org.bson.BsonArray; +import org.bson.BsonDocument; +import org.junit.runners.Parameterized; + +import java.io.IOException; +import java.net.URISyntaxException; +import java.util.Collection; + +public class UnifiedAuthTest extends UnifiedSyncTest { + public UnifiedAuthTest(@SuppressWarnings("unused") final String fileDescription, + @SuppressWarnings("unused") final String testDescription, + final String schemaVersion, final BsonArray runOnRequirements, final BsonArray entitiesArray, + final BsonArray initialData, final BsonDocument definition) { + super(schemaVersion, runOnRequirements, entitiesArray, initialData, definition); + } + + @Parameterized.Parameters(name = "{0}: {1}") + public static Collection data() throws URISyntaxException, IOException { + return getTestData("unified-test-format/auth"); + } +} diff --git a/driver-sync/src/test/functional/com/mongodb/internal/connection/OidcAuthenticationProseTests.java b/driver-sync/src/test/functional/com/mongodb/internal/connection/OidcAuthenticationProseTests.java new file mode 100644 index 00000000000..74e95d7e253 --- /dev/null +++ b/driver-sync/src/test/functional/com/mongodb/internal/connection/OidcAuthenticationProseTests.java @@ -0,0 +1,937 @@ +/* + * Copyright 2008-present MongoDB, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.mongodb.internal.connection; + +import com.mongodb.ConnectionString; +import com.mongodb.MongoClientSettings; +import com.mongodb.MongoCommandException; +import com.mongodb.MongoConfigurationException; +import com.mongodb.MongoCredential; +import com.mongodb.MongoCredential.IdpResponse; +import com.mongodb.MongoCredential.OidcRefreshCallback; +import com.mongodb.MongoSecurityException; +import com.mongodb.client.MongoClient; +import com.mongodb.client.MongoClients; +import com.mongodb.client.TestListener; +import com.mongodb.event.CommandListener; +import com.mongodb.lang.Nullable; +import org.bson.BsonArray; +import org.bson.BsonBoolean; +import org.bson.BsonDocument; +import org.bson.BsonInt32; +import org.bson.BsonString; +import org.jetbrains.annotations.NotNull; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.function.Executable; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.CsvSource; +import org.opentest4j.AssertionFailedError; +import org.opentest4j.MultipleFailuresError; + +import java.io.IOException; +import java.nio.charset.StandardCharsets; +import java.nio.file.Files; +import java.nio.file.NoSuchFileException; +import java.nio.file.Path; +import java.nio.file.Paths; +import java.time.Duration; +import java.util.Arrays; +import java.util.List; +import java.util.Random; +import java.util.concurrent.ConcurrentLinkedQueue; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.function.Supplier; +import java.util.stream.Collectors; +import java.util.stream.Stream; + +import static com.mongodb.MongoCredential.ALLOWED_HOSTS_KEY; +import static com.mongodb.MongoCredential.IdpInfo; +import static com.mongodb.MongoCredential.OidcRefreshContext; +import static com.mongodb.MongoCredential.OidcRequestCallback; +import static com.mongodb.MongoCredential.OidcRequestContext; +import static com.mongodb.MongoCredential.PROVIDER_NAME_KEY; +import static com.mongodb.MongoCredential.REFRESH_TOKEN_CALLBACK_KEY; +import static com.mongodb.MongoCredential.REQUEST_TOKEN_CALLBACK_KEY; +import static com.mongodb.MongoCredential.createOidcCredential; +import static com.mongodb.client.TestHelper.setEnvironmentVariable; +import static java.lang.System.getenv; +import static java.util.Arrays.asList; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.junit.jupiter.api.Assertions.fail; +import static org.junit.jupiter.api.Assumptions.assumeTrue; +import static util.ThreadTestHelpers.executeAll; + + +/** + * See + * Prose Tests. + */ +public class OidcAuthenticationProseTests { + + public static boolean oidcTestsEnabled() { + return Boolean.parseBoolean(getenv().get("OIDC_TESTS_ENABLED")); + } + + private static final String AWS_WEB_IDENTITY_TOKEN_FILE = "AWS_WEB_IDENTITY_TOKEN_FILE"; + + public static final String TOKEN_DIRECTORY = "/tmp/tokens/"; // TODO-OIDC + + protected static final String OIDC_URL = "mongodb://localhost/?authMechanism=MONGODB-OIDC"; + private static final String AWS_OIDC_URL = + "mongodb://localhost/?authMechanism=MONGODB-OIDC&authMechanismProperties=PROVIDER_NAME:aws"; + private String appName; + + protected MongoClient createMongoClient(final MongoClientSettings settings) { + return MongoClients.create(settings); + } + + protected void setOidcFile(final String file) { + setEnvironmentVariable(AWS_WEB_IDENTITY_TOKEN_FILE, TOKEN_DIRECTORY + file); + } + + @BeforeEach + public void beforeEach() { + assumeTrue(oidcTestsEnabled()); + // In each test, clearing the cache is not required, since there is no global cache + setOidcFile("test_user1"); + InternalStreamConnection.setRecordEverything(true); + this.appName = this.getClass().getSimpleName() + "-" + new Random().nextInt(Integer.MAX_VALUE); + } + + @AfterEach + public void afterEach() { + InternalStreamConnection.setRecordEverything(false); + } + + @ParameterizedTest + @CsvSource(delimiter = '#', value = { + // 1.1 to 1.5: + "test1p1 # test_user1 # " + OIDC_URL, + "test1p2 # test_user1 # mongodb://test_user1@localhost/?authMechanism=MONGODB-OIDC", + "test1p3 # test_user1 # mongodb://test_user1@localhost:27018/?authMechanism=MONGODB-OIDC&directConnection=true&readPreference=secondaryPreferred", + "test1p4 # test_user2 # mongodb://test_user2@localhost:27018/?authMechanism=MONGODB-OIDC&directConnection=true&readPreference=secondaryPreferred", + "test1p5 # invalid # mongodb://localhost:27018/?authMechanism=MONGODB-OIDC&directConnection=true&readPreference=secondaryPreferred", + }) + public void test1CallbackDrivenAuth(final String name, final String file, final String url) { + boolean shouldPass = !file.equals("invalid"); + setOidcFile(file); + // #. Create a request callback that returns a valid token. + OidcRequestCallback onRequest = createCallback(); + // #. Create a client with a URL of the form ... and the OIDC request callback. + MongoClientSettings clientSettings = createSettings(url, onRequest, null); + // #. Perform a find operation that succeeds / fails + if (shouldPass) { + performFind(clientSettings); + } else { + performFind( + clientSettings, + MongoCommandException.class, + "Command failed with error 18 (AuthenticationFailed)"); + } + } + + @ParameterizedTest + @CsvSource(delimiter = '#', value = { + // 1.6, both variants: + "'' # " + OIDC_URL, + "example.com # mongodb://localhost/?authMechanism=MONGODB-OIDC&ignored=example.com", + }) + public void test1p6CallbackDrivenAuthAllowedHostsBlocked(final String allowedHosts, final String url) { + // Create a client that uses the OIDC url and a request callback, and an ALLOWED_HOSTS that contains... + List allowedHostsList = asList(allowedHosts.split(",")); + MongoClientSettings settings = createSettings(url, createCallback(), null, allowedHostsList, null); + // #. Assert that a find operation fails with a client-side error. + performFind(settings, MongoSecurityException.class, ""); + } + + @Test + public void test1p7LockAvoidsExtraCallbackCalls() { + proveThatConcurrentCallbacksThrow(); + // The test requires that two operations are attempted concurrently. + // The delay on the next find should cause the initial request to delay + // and the ensuing refresh to block, rather than entering onRefresh. + // After blocking, this ensuing refresh thread will enter onRefresh. + AtomicInteger concurrent = new AtomicInteger(); + TestCallback onRequest = createCallback().setExpired().setConcurrentTracker(concurrent); + TestCallback onRefresh = createCallback().setConcurrentTracker(concurrent); + MongoClientSettings clientSettings = createSettings(OIDC_URL, onRequest, onRefresh); + try (MongoClient mongoClient = createMongoClient(clientSettings)) { + delayNextFind(); // cause both callbacks to be called + executeAll(2, () -> performFind(mongoClient)); + assertEquals(1, onRequest.getInvocations()); + assertEquals(1, onRefresh.getInvocations()); + } + } + + public void proveThatConcurrentCallbacksThrow() { + // ensure that, via delay, test callbacks throw when invoked concurrently + AtomicInteger c = new AtomicInteger(); + TestCallback request = createCallback().setConcurrentTracker(c).setDelayMs(5); + TestCallback refresh = createCallback().setConcurrentTracker(c); + IdpInfo serverInfo = new OidcAuthenticator.IdpInfoImpl("issuer", "clientId", asList()); + executeAll(() -> { + sleep(2); + assertThrows(RuntimeException.class, () -> { + refresh.onRefresh(new OidcAuthenticator.OidcRefreshContextImpl(serverInfo, "refToken", Duration.ofSeconds(1234))); + }); + }, () -> { + request.onRequest(new OidcAuthenticator.OidcRequestContextImpl(serverInfo, Duration.ofSeconds(1234))); + }); + } + + private void sleep(final long ms) { + try { + Thread.sleep(ms); + } catch (InterruptedException e) { + throw new RuntimeException(e); + } + } + + @ParameterizedTest + @CsvSource(delimiter = '#', value = { + // 2.1 to 2.3: + "test2p1 # test_user1 # " + AWS_OIDC_URL, + "test2p2 # test_user1 # mongodb://localhost:27018/?authMechanism=MONGODB-OIDC&authMechanismProperties=PROVIDER_NAME:aws&directConnection=true&readPreference=secondaryPreferred", + "test2p3 # test_user2 # mongodb://localhost:27018/?authMechanism=MONGODB-OIDC&authMechanismProperties=PROVIDER_NAME:aws&directConnection=true&readPreference=secondaryPreferred", + }) + public void test2AwsAutomaticAuth(final String name, final String file, final String url) { + setOidcFile(file); + // #. Create a client with a url of the form ... + MongoCredential credential = createOidcCredential(null) + .withMechanismProperty(PROVIDER_NAME_KEY, "aws"); + MongoClientSettings clientSettings = MongoClientSettings.builder() + .applicationName(appName) + .credential(credential) + .applyConnectionString(new ConnectionString(url)) + .build(); + // #. Perform a find operation that succeeds. + performFind(clientSettings); + } + + @Test + public void test2p4AllowedHostsIgnored() { + MongoClientSettings settings = createSettings( + AWS_OIDC_URL, null, null, Arrays.asList(), null); + performFind(settings); + } + + @Test + public void test3p1ValidCallbacks() { + String connectionString = "mongodb://test_user1@localhost/?authMechanism=MONGODB-OIDC"; + String expectedClientId = "0oadp0hpl7q3UIehP297"; + String expectedIssuer = "https://ebgxby0dw8.execute-api.us-west-1.amazonaws.com/default/mock-identity-config-oidc"; + Duration expectedSeconds = Duration.ofMinutes(5); + + TestCallback onRequest = createCallback().setExpired(); + TestCallback onRefresh = createCallback(); + // #. Verify that the request callback was called with the appropriate + // inputs, including the timeout parameter if possible. + // #. Verify that the refresh callback was called with the appropriate + // inputs, including the timeout parameter if possible. + OidcRequestCallback onRequest2 = (context) -> { + assertEquals(expectedClientId, context.getIdpInfo().getClientId()); + assertEquals(expectedIssuer, context.getIdpInfo().getIssuer()); + assertEquals(Arrays.asList(), context.getIdpInfo().getRequestScopes()); + assertEquals(expectedSeconds, context.getTimeout()); + return onRequest.onRequest(context); + }; + OidcRefreshCallback onRefresh2 = (context) -> { + assertEquals(expectedClientId, context.getIdpInfo().getClientId()); + assertEquals(expectedIssuer, context.getIdpInfo().getIssuer()); + assertEquals(Arrays.asList(), context.getIdpInfo().getRequestScopes()); + assertEquals(expectedSeconds, context.getTimeout()); + assertEquals("refreshToken", context.getRefreshToken()); + return onRefresh.onRefresh(context); + }; + MongoClientSettings clientSettings = createSettings(connectionString, onRequest2, onRefresh2); + try (MongoClient mongoClient = createMongoClient(clientSettings)) { + delayNextFind(); // cause both callbacks to be called + executeAll(2, () -> performFind(mongoClient)); + // Ensure that both callbacks were called + assertEquals(1, onRequest.getInvocations()); + assertEquals(1, onRefresh.getInvocations()); + } + } + + @Test + public void test3p2RequestCallbackReturnsNull() { + //noinspection ConstantConditions + OidcRequestCallback onRequest = (context) -> null; + MongoClientSettings settings = this.createSettings(OIDC_URL, onRequest, null); + performFind(settings, MongoConfigurationException.class, "Result of callback must not be null"); + } + + @Test + public void test3p3RefreshCallbackReturnsNull() { + TestCallback onRequest = createCallback().setExpired().setDelayMs(100); + //noinspection ConstantConditions + OidcRefreshCallback onRefresh = (context) -> null; + MongoClientSettings clientSettings = createSettings(OIDC_URL, onRequest, onRefresh); + try (MongoClient mongoClient = createMongoClient(clientSettings)) { + delayNextFind(); // cause both callbacks to be called + try { + executeAll(2, () -> performFind(mongoClient)); + } catch (MultipleFailuresError actual) { + assertEquals(1, actual.getFailures().size()); + assertCause( + MongoConfigurationException.class, + "Result of callback must not be null", + actual.getFailures().get(0)); + } + assertEquals(1, onRequest.getInvocations()); + } + } + + @Test + public void test3p4RequestCallbackReturnsInvalidData() { + // #. Create a client with a request callback that returns data not + // conforming to the OIDCRequestTokenResult with missing field(s). + // #. ... with extra field(s). - not possible + OidcRequestCallback onRequest = (context) -> { + //noinspection ConstantConditions + return new IdpResponse(null, null, null); + }; + // we ensure that the error is propagated + MongoClientSettings clientSettings = createSettings(OIDC_URL, onRequest, null); + try (MongoClient mongoClient = createMongoClient(clientSettings)) { + try { + performFind(mongoClient); + fail(); + } catch (Exception e) { + assertCause(IllegalArgumentException.class, "accessToken can not be null", e); + } + } + } + + @Test + public void test3p5RefreshCallbackReturnsInvalidData() { + TestCallback onRequest = createCallback().setExpired(); + OidcRefreshCallback onRefresh = (context) -> { + //noinspection ConstantConditions + return new IdpResponse(null, null, null); + }; + MongoClientSettings clientSettings = createSettings(OIDC_URL, onRequest, onRefresh); + try (MongoClient mongoClient = createMongoClient(clientSettings)) { + try { + executeAll(2, () -> performFind(mongoClient)); + } catch (MultipleFailuresError actual) { + assertEquals(1, actual.getFailures().size()); + assertCause( + IllegalArgumentException.class, + "accessToken can not be null", + actual.getFailures().get(0)); + } + assertEquals(1, onRequest.getInvocations()); + } + } + + // 3.6 Refresh Callback Returns Extra Data - not possible due to use of class + + @Test + public void test4p1CachedCredentialsCacheWithRefresh() { + // #. Create a new client with a request callback that gives credentials that expire in one minute. + TestCallback onRequest = createCallback().setExpired(); + TestCallback onRefresh = createCallback(); + MongoClientSettings clientSettings = createSettings(OIDC_URL, onRequest, onRefresh); + try (MongoClient mongoClient = createMongoClient(clientSettings)) { + // #. Create a new client with the same request callback and a refresh callback. + // Instead: + // 1. Delay the first find, causing the second find to authenticate a second connection + delayNextFind(); // cause both callbacks to be called + executeAll(2, () -> performFind(mongoClient)); + // #. Ensure that a find operation adds credentials to the cache. + // #. Ensure that a find operation results in a call to the refresh callback. + assertEquals(1, onRequest.getInvocations()); + assertEquals(1, onRefresh.getInvocations()); + // the refresh invocation will fail if the cached tokens are null + // so a success implies that credentials were present in the cache + } + } + + @Test + public void test4p2CachedCredentialsCacheWithNoRefresh() { + // #. Create a new client with a request callback that gives credentials that expire in one minute. + // #. Ensure that a find operation adds credentials to the cache. + // #. Create a new client with a request callback but no refresh callback. + // #. Ensure that a find operation results in a call to the request callback. + TestCallback onRequest = createCallback().setExpired(); + MongoClientSettings clientSettings = createSettings(OIDC_URL, onRequest, null); + try (MongoClient mongoClient = createMongoClient(clientSettings)) { + delayNextFind(); // cause both callbacks to be called + executeAll(2, () -> performFind(mongoClient)); + // test is the same as 4.1, but no onRefresh, and assert that the onRequest is called twice + assertEquals(2, onRequest.getInvocations()); + } + } + + // 4.3 Cache key includes callback - skipped: + // If the driver does not support using callback references or hashes as part of the cache key, skip this test. + + @Test + public void test4p4ErrorClearsCache() { + // #. Create a new client with a valid request callback that + // gives credentials that expire within 5 minutes and + // a refresh callback that gives invalid credentials. + + TestListener listener = new TestListener(); + ConcurrentLinkedQueue tokens = tokenQueue( + "test_user1", + "test_user1_expires", + "test_user1_expires", + "test_user1_1"); + TestCallback onRequest = createCallback() + .setExpired() + .setPathSupplier(() -> tokens.remove()) + .setEventListener(listener); + TestCallback onRefresh = createCallback() + .setPathSupplier(() -> tokens.remove()) + .setEventListener(listener); + + TestCommandListener commandListener = new TestCommandListener(listener); + + MongoClientSettings clientSettings = createSettings(OIDC_URL, onRequest, onRefresh, null, commandListener); + try (MongoClient mongoClient = createMongoClient(clientSettings)) { + // #. Ensure that a find operation adds a new entry to the cache. + performFind(mongoClient); + assertEquals(Arrays.asList( + "isMaster started", + "isMaster succeeded", + "onRequest invoked", + "read access token: test_user1", + "saslContinue started", + "saslContinue succeeded", + "find started", + "find succeeded" + ), listener.getEventStrings()); + listener.clear(); + + // #. Ensure that a subsequent find operation results in a 391 error. + failCommand(391, 1, "find"); + // ensure that the operation entirely fails, after attempting both potential fallback callbacks + assertThrows(MongoSecurityException.class, () -> performFind(mongoClient)); + assertEquals(Arrays.asList( + "find started", + "find failed", + "onRefresh invoked", + "read access token: test_user1_expires", + "saslStart started", + "saslStart failed", + // falling back to principal request, onRequest callback. + "saslStart started", + "saslStart succeeded", + "onRequest invoked", + "read access token: test_user1_expires", + "saslContinue started", + "saslContinue failed" + ), listener.getEventStrings()); + listener.clear(); + + // #. Ensure that the cache value cleared. + failCommand(391, 1, "find"); + performFind(mongoClient); + assertEquals(Arrays.asList( + "find started", + "find failed", + // falling back to principal request, onRequest callback. + // this implies that the cache has been cleared during the + // preceding find operation. + "saslStart started", + "saslStart succeeded", + "onRequest invoked", + "read access token: test_user1_1", + "saslContinue started", + "saslContinue succeeded", + // auth has finished + "find started", + "find succeeded" + ), listener.getEventStrings()); + listener.clear(); + } + } + + // not a prose test. + @Test + public void testEventListenerMustNotLogReauthentication() { + InternalStreamConnection.setRecordEverything(false); + + TestListener listener = new TestListener(); + ConcurrentLinkedQueue tokens = tokenQueue( + "test_user1", + "test_user1_expires", + "test_user1_expires", + "test_user1_1"); + TestCallback onRequest = createCallback() + .setExpired() + .setPathSupplier(() -> tokens.remove()) + .setEventListener(listener); + TestCallback onRefresh = createCallback() + .setPathSupplier(() -> tokens.remove()) + .setEventListener(listener); + + TestCommandListener commandListener = new TestCommandListener(listener); + + MongoClientSettings clientSettings = createSettings(OIDC_URL, onRequest, onRefresh, null, commandListener); + try (MongoClient mongoClient = createMongoClient(clientSettings)) { + performFind(mongoClient); + assertEquals(Arrays.asList( + "onRequest invoked", + "read access token: test_user1", + "find started", + "find succeeded" + ), listener.getEventStrings()); + listener.clear(); + + failCommand(391, 1, "find"); + assertThrows(MongoSecurityException.class, () -> performFind(mongoClient)); + assertEquals(Arrays.asList( + "find started", + "find failed", + "onRefresh invoked", + "read access token: test_user1_expires", + // falling back to principal request, onRequest callback + "onRequest invoked", + "read access token: test_user1_expires" + ), listener.getEventStrings()); + } + } + + @Test + public void test4p5AwsAutomaticWorkflowDoesNotUseCache() { + // #. Create a new client that uses the AWS automatic workflow. + // #. Ensure that a find operation does not add credentials to the cache. + setOidcFile("test_user1"); + MongoCredential credential = createOidcCredential(null) + .withMechanismProperty(PROVIDER_NAME_KEY, "aws"); + ConnectionString connectionString = new ConnectionString(AWS_OIDC_URL); + MongoClientSettings clientSettings = MongoClientSettings.builder() + .applicationName(appName) + .credential(credential) + .applyConnectionString(connectionString) + .build(); + try (MongoClient mongoClient = createMongoClient(clientSettings)) { + performFind(mongoClient); + // This ensures that the next find failure results in a file (rather than cache) read + failCommand(391, 1, "find"); + setOidcFile("invalid_file"); + assertCause(NoSuchFileException.class, "invalid_file", () -> performFind(mongoClient)); + } + } + + @Test + public void test5SpeculativeAuthentication() { + // #. We can only test the successful case, by verifying that saslStart is not called. + // #. Create a client with a request callback that returns a valid token that will not expire soon. + TestListener listener = new TestListener(); + TestCallback onRequest = createCallback().setEventListener(listener); + TestCommandListener commandListener = new TestCommandListener(listener); + MongoClientSettings clientSettings = createSettings(OIDC_URL, onRequest, null, null, commandListener); + try (MongoClient mongoClient = createMongoClient(clientSettings)) { + // instead of setting failpoints for saslStart, we inspect events + delayNextFind(); + executeAll(2, () -> performFind(mongoClient)); + + List events = listener.getEventStrings(); + assertFalse(events.stream().anyMatch(e -> e.contains("saslStart"))); + // onRequest is 2-step, so we expect 2 continues + assertEquals(2, events.stream().filter(e -> e.contains("saslContinue started")).count()); + // confirm all commands are enabled + assertTrue(events.stream().anyMatch(e -> e.contains("isMaster started"))); + } + } + + // Not a prose test + @Test + public void testAutomaticAuthUsesSpeculative() { + TestListener listener = new TestListener(); + TestCommandListener commandListener = new TestCommandListener(listener); + + MongoClientSettings settings = createSettings( + AWS_OIDC_URL, null, null, Arrays.asList(), commandListener); + try (MongoClient mongoClient = createMongoClient(settings)) { + // we use a listener instead of a failpoint + performFind(mongoClient); + assertEquals(Arrays.asList( + "isMaster started", + "isMaster succeeded", + "find started", + "find succeeded" + ), listener.getEventStrings()); + } + } + + @Test + public void test6p1ReauthenticationSucceeds() { + // #. Create request and refresh callbacks that return valid credentials that will not expire soon. + TestListener listener = new TestListener(); + TestCallback onRequest = createCallback().setEventListener(listener); + TestCallback onRefresh = createCallback().setEventListener(listener); + + // #. Create a client with the callbacks and an event listener capable of listening for SASL commands. + TestCommandListener commandListener = new TestCommandListener(listener); + + MongoClientSettings clientSettings = createSettings(OIDC_URL, onRequest, onRefresh, null, commandListener); + try (MongoClient mongoClient = createMongoClient(clientSettings)) { + + // #. Perform a find operation that succeeds. + performFind(mongoClient); + + // #. Assert that the refresh callback has not been called. + assertEquals(0, onRefresh.getInvocations()); + + assertEquals(Arrays.asList( + "isMaster started", + "isMaster succeeded", + "onRequest invoked", + "read access token: test_user1", + "saslContinue started", + "saslContinue succeeded", + "find started", + "find succeeded" + ), listener.getEventStrings()); + + // #. Clear the listener state if possible. + commandListener.reset(); + listener.clear(); + + // #. Force a reauthenication using a failCommand + failCommand(391, 1, "find"); + + // #. Perform another find operation that succeeds. + performFind(mongoClient); + + // #. Assert that the ordering of command started events is: find, find. + // #. Assert that the ordering of command succeeded events is: find. + // #. Assert that a find operation failed once during the command execution. + assertEquals(Arrays.asList( + "find started", + "find failed", + "onRefresh invoked", + "read access token: test_user1", + "saslStart started", + "saslStart succeeded", + "find started", + "find succeeded" + ), listener.getEventStrings()); + + // #. Assert that the refresh callback has been called once, if possible. + assertEquals(1, onRefresh.getInvocations()); + } + } + + @NotNull + private ConcurrentLinkedQueue tokenQueue(final String... queue) { + return Stream + .of(queue) + .map(v -> TOKEN_DIRECTORY + v) + .collect(Collectors.toCollection(ConcurrentLinkedQueue::new)); + } + + @Test + public void test6p2ReauthenticationRetriesAndSucceedsWithCache() { + // #. Create request and refresh callbacks that return valid credentials that will not expire soon. + TestCallback onRequest = createCallback(); + TestCallback onRefresh = createCallback(); + MongoClientSettings clientSettings = createSettings(OIDC_URL, onRequest, onRefresh); + try (MongoClient mongoClient = createMongoClient(clientSettings)) { + // #. Perform a find operation that succeeds. + performFind(mongoClient); + // #. Force a reauthenication using a failCommand + failCommand(391, 1, "find"); + // #. Perform a find operation that succeeds. + performFind(mongoClient); + } + } + + // 6.3 Retries and Fails with no Cache + // Appears to be untestable, since it requires 391 failure on jwt (may be fixed in future spec) + + @Test + public void test6p4SeparateConnectionsAvoidExtraCallbackCalls() { + ConcurrentLinkedQueue tokens = tokenQueue( + "test_user1", + "test_user1_1"); + TestCallback onRequest = createCallback().setPathSupplier(() -> tokens.remove()); + TestCallback onRefresh = createCallback().setPathSupplier(() -> tokens.remove()); + MongoClientSettings clientSettings = createSettings(OIDC_URL, onRequest, onRefresh); + try (MongoClient mongoClient = createMongoClient(clientSettings)) { + // #. Peform a find operation on each ... that succeeds. + delayNextFind(); + executeAll(2, () -> performFind(mongoClient)); + // #. Ensure that the request callback has been called once and the refresh callback has not been called. + assertEquals(1, onRequest.getInvocations()); + assertEquals(0, onRefresh.getInvocations()); + + failCommand(391, 2, "find"); + executeAll(2, () -> performFind(mongoClient)); + + // #. Ensure that the request callback has been called once and the refresh callback has been called once. + assertEquals(1, onRequest.getInvocations()); + assertEquals(1, onRefresh.getInvocations()); + } + } + + public MongoClientSettings createSettings( + final String connectionString, + @Nullable final OidcRequestCallback onRequest, + @Nullable final OidcRefreshCallback onRefresh) { + return createSettings(connectionString, onRequest, onRefresh, null, null); + } + + private MongoClientSettings createSettings( + final String connectionString, + @Nullable final OidcRequestCallback onRequest, + @Nullable final OidcRefreshCallback onRefresh, + @Nullable final List allowedHosts, + @Nullable final CommandListener commandListener) { + ConnectionString cs = new ConnectionString(connectionString); + MongoCredential credential = cs.getCredential() + .withMechanismProperty(REQUEST_TOKEN_CALLBACK_KEY, onRequest) + .withMechanismProperty(REFRESH_TOKEN_CALLBACK_KEY, onRefresh) + .withMechanismProperty(ALLOWED_HOSTS_KEY, allowedHosts); + MongoClientSettings.Builder builder = MongoClientSettings.builder() + .applicationName(appName) + .applyConnectionString(cs) + .credential(credential); + if (commandListener != null) { + builder.addCommandListener(commandListener); + } + return builder.build(); + } + + private void performFind(final MongoClientSettings settings) { + try (MongoClient mongoClient = createMongoClient(settings)) { + performFind(mongoClient); + } + } + + private void performFind( + final MongoClientSettings settings, + final Class expectedExceptionOrCause, + final String expectedMessage) { + try (MongoClient mongoClient = createMongoClient(settings)) { + assertCause(expectedExceptionOrCause, expectedMessage, () -> performFind(mongoClient)); + } + } + + private void performFind(final MongoClient mongoClient) { + mongoClient + .getDatabase("test") + .getCollection("test") + .find() + .first(); + } + + private static void assertCause( + final Class expectedCause, final String expectedMessageFragment, final Executable e) { + Throwable actualException = assertThrows(Throwable.class, e); + assertCause(expectedCause, expectedMessageFragment, actualException); + } + + private static void assertCause( + final Class expectedCause, final String expectedMessageFragment, final Throwable actualException) { + Throwable cause = actualException; + while (cause.getCause() != null) { + cause = cause.getCause(); + } + if (!expectedCause.isInstance(cause)) { + throw new AssertionFailedError("Unexpected cause", actualException); + } + if (!cause.getMessage().contains(expectedMessageFragment)) { + throw new AssertionFailedError("Unexpected message", actualException); + } + } + + protected void delayNextFind() { + try (MongoClient client = createMongoClient(createSettings(AWS_OIDC_URL, null, null))) { + BsonDocument failPointDocument = new BsonDocument("configureFailPoint", new BsonString("failCommand")) + .append("mode", new BsonDocument("times", new BsonInt32(1))) + .append("data", new BsonDocument() + .append("appName", new BsonString(appName)) + .append("failCommands", new BsonArray(asList(new BsonString("find")))) + .append("blockConnection", new BsonBoolean(true)) + .append("blockTimeMS", new BsonInt32(100))); + client.getDatabase("admin").runCommand(failPointDocument); + } + } + + protected void failCommand(final int code, final int times, final String... commands) { + try (MongoClient mongoClient = createMongoClient(createSettings( + AWS_OIDC_URL, null, null))) { + List list = Arrays.stream(commands).map(c -> new BsonString(c)).collect(Collectors.toList()); + BsonDocument failPointDocument = new BsonDocument("configureFailPoint", new BsonString("failCommand")) + .append("mode", new BsonDocument("times", new BsonInt32(times))) + .append("data", new BsonDocument() + .append("appName", new BsonString(appName)) + .append("failCommands", new BsonArray(list)) + .append("errorCode", new BsonInt32(code))); + mongoClient.getDatabase("admin").runCommand(failPointDocument); + } + } + + public static class TestCallback implements OidcRequestCallback, OidcRefreshCallback { + private final AtomicInteger invocations = new AtomicInteger(); + @Nullable + private final Integer expiresInSeconds; + @Nullable + private final Integer delayInMilliseconds; + @Nullable + private final AtomicInteger concurrentTracker; + @Nullable + private final TestListener testListener; + @Nullable + private final Supplier pathSupplier; + + public TestCallback() { + this(60 * 60, null, new AtomicInteger(), null, null); + } + + public TestCallback( + @Nullable final Integer expiresInSeconds, + @Nullable final Integer delayInMilliseconds, + @Nullable final AtomicInteger concurrentTracker, + @Nullable final TestListener testListener, + @Nullable final Supplier pathSupplier) { + this.expiresInSeconds = expiresInSeconds; + this.delayInMilliseconds = delayInMilliseconds; + this.concurrentTracker = concurrentTracker; + this.testListener = testListener; + this.pathSupplier = pathSupplier; + } + + public int getInvocations() { + return invocations.get(); + } + + @Override + public IdpResponse onRequest(final OidcRequestContext context) { + if (testListener != null) { + testListener.add("onRequest invoked"); + } + return callback(); + } + + @Override + public IdpResponse onRefresh(final OidcRefreshContext context) { + if (context.getRefreshToken() == null) { + throw new IllegalArgumentException("refreshToken was null"); + } + if (testListener != null) { + testListener.add("onRefresh invoked"); + } + return callback(); + } + + @NotNull + private IdpResponse callback() { + if (concurrentTracker != null) { + if (concurrentTracker.get() > 0) { + throw new RuntimeException("Callbacks should not be invoked by multiple threads."); + } + concurrentTracker.incrementAndGet(); + } + try { + invocations.incrementAndGet(); + Path path = Paths.get(pathSupplier == null + ? getenv(AWS_WEB_IDENTITY_TOKEN_FILE) + : pathSupplier.get()); + String accessToken; + try { + simulateDelay(); + accessToken = new String(Files.readAllBytes(path), StandardCharsets.UTF_8); + } catch (IOException | InterruptedException e) { + throw new RuntimeException(e); + } + String refreshToken = "refreshToken"; + if (testListener != null) { + testListener.add("read access token: " + path.getFileName()); + } + return new IdpResponse( + accessToken, + expiresInSeconds, + refreshToken); + } finally { + if (concurrentTracker != null) { + concurrentTracker.decrementAndGet(); + } + } + } + + private void simulateDelay() throws InterruptedException { + if (delayInMilliseconds != null) { + Thread.sleep(delayInMilliseconds); + } + } + + public TestCallback setExpiresInSeconds(final Integer expiresInSeconds) { + return new TestCallback( + expiresInSeconds, + this.delayInMilliseconds, + this.concurrentTracker, + this.testListener, + this.pathSupplier); + } + + public TestCallback setDelayMs(final int milliseconds) { + return new TestCallback( + this.expiresInSeconds, + milliseconds, + this.concurrentTracker, + this.testListener, + this.pathSupplier); + } + + public TestCallback setConcurrentTracker(final AtomicInteger c) { + return new TestCallback( + this.expiresInSeconds, + this.delayInMilliseconds, + c, + this.testListener, + this.pathSupplier); + } + + public TestCallback setEventListener(final TestListener testListener) { + return new TestCallback( + this.expiresInSeconds, + this.delayInMilliseconds, + this.concurrentTracker, + testListener, + this.pathSupplier); + } + + public TestCallback setPathSupplier(final Supplier pathSupplier) { + return new TestCallback( + this.expiresInSeconds, + this.delayInMilliseconds, + this.concurrentTracker, + this.testListener, + pathSupplier); + } + + public TestCallback setExpired() { + return this.setExpiresInSeconds(60); + } + } + + public TestCallback createCallback() { + return new TestCallback(); + } +} From 5a2f145c846f015b38718038fb51f67ff0ff053c Mon Sep 17 00:00:00 2001 From: Maxim Katcharov Date: Wed, 28 Jun 2023 11:57:28 -0600 Subject: [PATCH 2/6] Implement OIDC auth for async (#1131) JAVA-4981 --- .../com/mongodb/assertions/Assertions.java | 36 -------- .../src/main/com/mongodb/internal/Locks.java | 22 ++++- .../internal/connection/Authenticator.java | 7 ++ .../connection/InternalConnection.java | 2 +- .../connection/InternalStreamConnection.java | 63 ++++++++++--- .../connection/MongoCredentialWithCache.java | 1 - .../connection/OidcAuthenticator.java | 91 +++++++++++++++---- .../connection/SaslAuthenticator.java | 10 +- .../OidcAuthenticationAsyncProseTests.java | 70 ++++++++++++++ .../OidcAuthenticationProseTests.java | 6 ++ 10 files changed, 228 insertions(+), 80 deletions(-) create mode 100644 driver-reactive-streams/src/test/functional/com/mongodb/internal/connection/OidcAuthenticationAsyncProseTests.java diff --git a/driver-core/src/main/com/mongodb/assertions/Assertions.java b/driver-core/src/main/com/mongodb/assertions/Assertions.java index ae30c179e85..9866c222c6d 100644 --- a/driver-core/src/main/com/mongodb/assertions/Assertions.java +++ b/driver-core/src/main/com/mongodb/assertions/Assertions.java @@ -17,7 +17,6 @@ package com.mongodb.assertions; -import com.mongodb.internal.async.SingleResultCallback; import com.mongodb.lang.Nullable; import java.util.Collection; @@ -79,25 +78,6 @@ public static Iterable notNullElements(final String name, final Iterable< return values; } - /** - * Throw IllegalArgumentException if the value is null. - * - * @param name the parameter name - * @param value the value that should not be null - * @param callback the callback that also is passed the exception if the value is null - * @param the value type - * @return the value - * @throws java.lang.IllegalArgumentException if value is null - */ - public static T notNull(final String name, final T value, final SingleResultCallback callback) { - if (value == null) { - IllegalArgumentException exception = new IllegalArgumentException(name + " can not be null"); - callback.completeExceptionally(exception); - throw exception; - } - return value; - } - /** * Throw IllegalStateException if the condition if false. * @@ -111,22 +91,6 @@ public static void isTrue(final String name, final boolean condition) { } } - /** - * Throw IllegalStateException if the condition if false. - * - * @param name the name of the state that is being checked - * @param condition the condition about the parameter to check - * @param callback the callback that also is passed the exception if the condition is not true - * @throws java.lang.IllegalStateException if the condition is false - */ - public static void isTrue(final String name, final boolean condition, final SingleResultCallback callback) { - if (!condition) { - IllegalStateException exception = new IllegalStateException("state should be: " + name); - callback.completeExceptionally(exception); - throw exception; - } - } - /** * Throw IllegalArgumentException if the condition if false. * diff --git a/driver-core/src/main/com/mongodb/internal/Locks.java b/driver-core/src/main/com/mongodb/internal/Locks.java index f727caf20f0..2a169f45c52 100644 --- a/driver-core/src/main/com/mongodb/internal/Locks.java +++ b/driver-core/src/main/com/mongodb/internal/Locks.java @@ -17,6 +17,8 @@ package com.mongodb.internal; import com.mongodb.MongoInterruptedException; +import com.mongodb.internal.async.AsyncRunnable; +import com.mongodb.internal.async.SingleResultCallback; import java.util.concurrent.locks.Lock; import java.util.concurrent.locks.ReentrantLock; @@ -36,7 +38,23 @@ public static void withLock(final Lock lock, final Runnable action) { }); } - public static V withLock(final StampedLock lock, final Supplier supplier) { + public static void withLockAsync(final StampedLock lock, final AsyncRunnable runnable, + final SingleResultCallback callback) { + long stamp; + try { + stamp = lock.writeLockInterruptibly(); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + callback.onResult(null, new MongoInterruptedException("Interrupted waiting for lock", e)); + return; + } + + runnable.thenAlwaysRunAndFinish(() -> { + lock.unlockWrite(stamp); + }, callback); + } + + public static void withLock(final StampedLock lock, final Runnable runnable) { long stamp; try { stamp = lock.writeLockInterruptibly(); @@ -45,7 +63,7 @@ public static V withLock(final StampedLock lock, final Supplier supplier) throw new MongoInterruptedException("Interrupted waiting for lock", e); } try { - return supplier.get(); + runnable.run(); } finally { lock.unlockWrite(stamp); } diff --git a/driver-core/src/main/com/mongodb/internal/connection/Authenticator.java b/driver-core/src/main/com/mongodb/internal/connection/Authenticator.java index 45e0b078452..232eeb45049 100644 --- a/driver-core/src/main/com/mongodb/internal/connection/Authenticator.java +++ b/driver-core/src/main/com/mongodb/internal/connection/Authenticator.java @@ -27,6 +27,7 @@ import com.mongodb.lang.Nullable; import static com.mongodb.assertions.Assertions.notNull; +import static com.mongodb.internal.async.AsyncRunnable.beginAsync; /** *

This class is not part of the public API and may be removed or changed at any time

@@ -104,4 +105,10 @@ public void reauthenticate(final InternalConnection connection) { authenticate(connection, connection.getDescription()); } + public void reauthenticateAsync(final InternalConnection connection, final SingleResultCallback callback) { + beginAsync().thenRun((c) -> { + authenticateAsync(connection, connection.getDescription(), c); + }).finish(callback); + } + } diff --git a/driver-core/src/main/com/mongodb/internal/connection/InternalConnection.java b/driver-core/src/main/com/mongodb/internal/connection/InternalConnection.java index 405ef31f5cf..e2b0188572e 100644 --- a/driver-core/src/main/com/mongodb/internal/connection/InternalConnection.java +++ b/driver-core/src/main/com/mongodb/internal/connection/InternalConnection.java @@ -49,7 +49,7 @@ public interface InternalConnection extends BufferProvider { ServerDescription getInitialServerDescription(); /** - * Opens the connection so its ready for use + * Opens the connection so its ready for use. Will perform a handshake. */ void open(); diff --git a/driver-core/src/main/com/mongodb/internal/connection/InternalStreamConnection.java b/driver-core/src/main/com/mongodb/internal/connection/InternalStreamConnection.java index 143d66f2096..09f2ec6a845 100644 --- a/driver-core/src/main/com/mongodb/internal/connection/InternalStreamConnection.java +++ b/driver-core/src/main/com/mongodb/internal/connection/InternalStreamConnection.java @@ -42,6 +42,7 @@ import com.mongodb.event.CommandListener; import com.mongodb.internal.ResourceUtil; import com.mongodb.internal.VisibleForTesting; +import com.mongodb.internal.async.AsyncSupplier; import com.mongodb.internal.async.SingleResultCallback; import com.mongodb.internal.diagnostics.logging.Logger; import com.mongodb.internal.diagnostics.logging.Loggers; @@ -68,9 +69,12 @@ import java.util.function.Supplier; import static com.mongodb.assertions.Assertions.assertNotNull; +import static com.mongodb.assertions.Assertions.assertNull; import static com.mongodb.assertions.Assertions.isTrue; import static com.mongodb.assertions.Assertions.notNull; +import static com.mongodb.internal.async.AsyncRunnable.beginAsync; import static com.mongodb.internal.async.ErrorHandlingResultCallback.errorHandlingCallback; +import static com.mongodb.internal.connection.Authenticator.shouldAuthenticate; import static com.mongodb.internal.connection.CommandHelper.HELLO; import static com.mongodb.internal.connection.CommandHelper.LEGACY_HELLO; import static com.mongodb.internal.connection.CommandHelper.LEGACY_HELLO_LOWER; @@ -238,7 +242,7 @@ public void open() { @Override public void openAsync(final SingleResultCallback callback) { - isTrue("Open already called", stream == null, callback); + assertNull(stream); try { stream = streamFactory.create(serverId.getAddress()); stream.openAsync(new AsyncCompletionHandler() { @@ -364,17 +368,48 @@ public T sendAndReceive(final CommandMessage message, final Decoder decod try { return sendAndReceiveInternal.get(); } catch (MongoCommandException e) { - if (triggersReauthentication(e) && Authenticator.shouldAuthenticate(authenticator, this.description)) { - authenticated.set(false); - authenticator.reauthenticate(this); - authenticated.set(true); - return sendAndReceiveInternal.get(); + if (reauthenticationIsTriggered(e)) { + return reauthenticateAndRetry(sendAndReceiveInternal); } throw e; } } - public static boolean triggersReauthentication(@Nullable final Throwable t) { + @Override + public void sendAndReceiveAsync(final CommandMessage message, final Decoder decoder, final SessionContext sessionContext, + final RequestContext requestContext, final OperationContext operationContext, final SingleResultCallback callback) { + + AsyncSupplier sendAndReceiveAsyncInternal = c -> sendAndReceiveAsyncInternal( + message, decoder, sessionContext, requestContext, operationContext, c); + beginAsync().thenSupply(c -> { + sendAndReceiveAsyncInternal.getAsync(c); + }).onErrorIf(e -> reauthenticationIsTriggered(e), c -> { + reauthenticateAndRetryAsync(sendAndReceiveAsyncInternal, c); + }).finish(callback); + } + + private T reauthenticateAndRetry(final Supplier operation) { + authenticated.set(false); + assertNotNull(authenticator).reauthenticate(this); + authenticated.set(true); + return operation.get(); + } + + private void reauthenticateAndRetryAsync(final AsyncSupplier operation, + final SingleResultCallback callback) { + beginAsync().thenRun(c -> { + authenticated.set(false); + assertNotNull(authenticator).reauthenticateAsync(this, c); + }).thenSupply((c) -> { + authenticated.set(true); + operation.getAsync(c); + }).finish(callback); + } + + public boolean reauthenticationIsTriggered(@Nullable final Throwable t) { + if (!shouldAuthenticate(authenticator, this.description)) { + return false; + } if (t instanceof MongoCommandException) { MongoCommandException e = (MongoCommandException) t; return e.getErrorCode() == 391; @@ -501,11 +536,8 @@ private T receiveCommandMessageResponse(final Decoder decoder, } } - @Override - public void sendAndReceiveAsync(final CommandMessage message, final Decoder decoder, final SessionContext sessionContext, + private void sendAndReceiveAsyncInternal(final CommandMessage message, final Decoder decoder, final SessionContext sessionContext, final RequestContext requestContext, final OperationContext operationContext, final SingleResultCallback callback) { - notNull("stream is open", stream, callback); - if (isClosed()) { callback.onResult(null, new MongoSocketClosedException("Can not read from a closed socket", getServerAddress())); return; @@ -616,7 +648,7 @@ public void sendMessage(final List byteBuffers, final int lastRequestId @Override public ResponseBuffers receiveMessage(final int responseTo) { - notNull("stream is open", stream); + assertNotNull(stream); if (isClosed()) { throw new MongoSocketClosedException("Cannot read from a closed stream", getServerAddress()); } @@ -634,8 +666,9 @@ private ResponseBuffers receiveMessageWithAdditionalTimeout(final int additional } @Override - public void sendMessageAsync(final List byteBuffers, final int lastRequestId, final SingleResultCallback callback) { - notNull("stream is open", stream, callback); + public void sendMessageAsync(final List byteBuffers, final int lastRequestId, + final SingleResultCallback callback) { + assertNotNull(stream); if (isClosed()) { callback.onResult(null, new MongoSocketClosedException("Can not read from a closed socket", getServerAddress())); @@ -667,7 +700,7 @@ public void failed(final Throwable t) { @Override public void receiveMessageAsync(final int responseTo, final SingleResultCallback callback) { - isTrue("stream is open", stream != null, callback); + assertNotNull(stream); if (isClosed()) { callback.onResult(null, new MongoSocketClosedException("Can not read from a closed socket", getServerAddress())); diff --git a/driver-core/src/main/com/mongodb/internal/connection/MongoCredentialWithCache.java b/driver-core/src/main/com/mongodb/internal/connection/MongoCredentialWithCache.java index 3f3369059c3..682637bf9ed 100644 --- a/driver-core/src/main/com/mongodb/internal/connection/MongoCredentialWithCache.java +++ b/driver-core/src/main/com/mongodb/internal/connection/MongoCredentialWithCache.java @@ -25,7 +25,6 @@ import java.util.concurrent.locks.StampedLock; import static com.mongodb.internal.Locks.withInterruptibleLock; -import static com.mongodb.internal.Locks.withLock; import static com.mongodb.internal.connection.OidcAuthenticator.OidcCacheEntry; /** diff --git a/driver-core/src/main/com/mongodb/internal/connection/OidcAuthenticator.java b/driver-core/src/main/com/mongodb/internal/connection/OidcAuthenticator.java index f3c931a433f..2d3387e9216 100644 --- a/driver-core/src/main/com/mongodb/internal/connection/OidcAuthenticator.java +++ b/driver-core/src/main/com/mongodb/internal/connection/OidcAuthenticator.java @@ -30,8 +30,9 @@ import com.mongodb.connection.ClusterConnectionMode; import com.mongodb.connection.ConnectionDescription; import com.mongodb.internal.Locks; -import com.mongodb.internal.Timeout; +import com.mongodb.internal.async.SingleResultCallback; import com.mongodb.internal.VisibleForTesting; +import com.mongodb.internal.time.Timeout; import com.mongodb.lang.Nullable; import org.bson.BsonDocument; import org.bson.BsonString; @@ -66,6 +67,7 @@ import static com.mongodb.assertions.Assertions.assertFalse; import static com.mongodb.assertions.Assertions.assertNotNull; import static com.mongodb.assertions.Assertions.assertTrue; +import static com.mongodb.internal.async.AsyncRunnable.beginAsync; import static com.mongodb.internal.connection.OidcAuthenticator.OidcValidator.validateBeforeUse; import static java.lang.String.format; @@ -180,32 +182,66 @@ private OidcRequestCallback getRequestCallback() { @Override public void reauthenticate(final InternalConnection connection) { - // method must only be called after original handshake: assertTrue(connection.opened()); authLock(connection, connection.getDescription()); } + @Override + public void reauthenticateAsync(final InternalConnection connection, final SingleResultCallback callback) { + beginAsync().thenRun(c -> { + assertTrue(connection.opened()); + authLockAsync(connection, connection.getDescription(), c); + }).finish(callback); + } + @Override public void authenticate(final InternalConnection connection, final ConnectionDescription connectionDescription) { - // method must only be called during original handshake: assertFalse(connection.opened()); - // this method "wraps" the default authentication method in custom OIDC retry logic String accessToken = getValidCachedAccessToken(); if (accessToken != null) { - try { - authenticateUsing(connection, connectionDescription, (challenge) -> prepareTokenAsJwt(accessToken)); - } catch (MongoSecurityException e) { - if (triggersRetry(e)) { - authLock(connection, connectionDescription); - } else { - throw e; - } - } + authenticateOptimistically(connection, connectionDescription, accessToken); } else { authLock(connection, connectionDescription); } } + @Override + void authenticateAsync(final InternalConnection connection, final ConnectionDescription connectionDescription, + final SingleResultCallback callback) { + beginAsync().thenRun(c -> { + assertFalse(connection.opened()); + String accessToken = getValidCachedAccessToken(); + if (accessToken != null) { + authenticateOptimisticallyAsync(connection, connectionDescription, accessToken, c); + } else { + authLockAsync(connection, connectionDescription, c); + } + }).finish(callback); + } + + private void authenticateOptimistically(final InternalConnection connection, + final ConnectionDescription connectionDescription, final String accessToken) { + try { + authenticateUsingFunction(connection, connectionDescription, (challenge) -> prepareTokenAsJwt(accessToken)); + } catch (MongoSecurityException e) { + if (triggersRetry(e)) { + authLock(connection, connectionDescription); + } else { + throw e; + } + } + } + + private void authenticateOptimisticallyAsync(final InternalConnection connection, + final ConnectionDescription connectionDescription, final String accessToken, + final SingleResultCallback callback) { + beginAsync().thenRun(c -> { + authenticateUsingFunctionAsync(connection, connectionDescription, (challenge) -> prepareTokenAsJwt(accessToken), c); + }).onErrorIf(e -> triggersRetry(e), c -> { + authLockAsync(connection, connectionDescription, c); + }).finish(callback); + } + private static boolean triggersRetry(@Nullable final Throwable t) { if (t instanceof MongoSecurityException) { MongoSecurityException e = (MongoSecurityException) t; @@ -218,7 +254,14 @@ private static boolean triggersRetry(@Nullable final Throwable t) { return false; } - private void authenticateUsing( + private void authenticateUsingFunctionAsync(final InternalConnection connection, + final ConnectionDescription connectionDescription, final Function evaluateChallengeFunction, + final SingleResultCallback callback) { + this.evaluateChallengeFunction = evaluateChallengeFunction; + super.authenticateAsync(connection, connectionDescription, callback); + } + + private void authenticateUsingFunction( final InternalConnection connection, final ConnectionDescription connectionDescription, final Function evaluateChallengeFunction) { @@ -226,23 +269,33 @@ private void authenticateUsing( super.authenticate(connection, connectionDescription); } - private void authLock(final InternalConnection connection, final ConnectionDescription connectionDescription) { + private void authLock(final InternalConnection connection, final ConnectionDescription description) { fallbackState = FallbackState.INITIAL; Locks.withLock(getMongoCredentialWithCache().getOidcLock(), () -> { while (true) { try { - authenticateUsing(connection, connectionDescription, (challenge) -> evaluate(challenge)); + authenticateUsingFunction(connection, description, (challenge) -> evaluate(challenge)); break; } catch (MongoSecurityException e) { - if (!(triggersRetry(e) && shouldRetryHandler())) { - throw e; + if (triggersRetry(e) && shouldRetryHandler()) { + continue; } + throw e; } } - return null; }); } + private void authLockAsync(final InternalConnection connection, final ConnectionDescription description, + final SingleResultCallback callback) { + fallbackState = FallbackState.INITIAL; + Locks.withLockAsync(getMongoCredentialWithCache().getOidcLock(), + beginAsync().thenRunRetryingWhile( + c -> authenticateUsingFunctionAsync(connection, description, (challenge) -> evaluate(challenge), c), + e -> triggersRetry(e) && shouldRetryHandler() + ), callback); + } + private byte[] evaluate(final byte[] challenge) { if (isAutomaticAuthentication()) { return prepareAwsTokenFromFileAsJwt(); diff --git a/driver-core/src/main/com/mongodb/internal/connection/SaslAuthenticator.java b/driver-core/src/main/com/mongodb/internal/connection/SaslAuthenticator.java index 335dce38a57..6e4bea55514 100644 --- a/driver-core/src/main/com/mongodb/internal/connection/SaslAuthenticator.java +++ b/driver-core/src/main/com/mongodb/internal/connection/SaslAuthenticator.java @@ -128,11 +128,9 @@ private void throwIfSaslClientIsNull(@Nullable final SaslClient saslClient) { } private BsonDocument getNextSaslResponse(final SaslClient saslClient, final InternalConnection connection) { - if (!connection.opened()) { - BsonDocument response = getSpeculativeAuthenticateResponse(); - if (response != null) { - return response; - } + BsonDocument response = connection.opened() ? null : getSpeculativeAuthenticateResponse(); + if (response != null) { + return response; } try { @@ -147,7 +145,7 @@ private void getNextSaslResponseAsync(final SaslClient saslClient, final Interna final SingleResultCallback callback) { SingleResultCallback errHandlingCallback = errorHandlingCallback(callback, LOGGER); try { - BsonDocument response = getSpeculativeAuthenticateResponse(); + BsonDocument response = connection.opened() ? null : getSpeculativeAuthenticateResponse(); if (response == null) { byte[] serverResponse = (saslClient.hasInitialResponse() ? saslClient.evaluateChallenge(new byte[0]) : null); sendSaslStartAsync(serverResponse, connection, (result, t) -> { diff --git a/driver-reactive-streams/src/test/functional/com/mongodb/internal/connection/OidcAuthenticationAsyncProseTests.java b/driver-reactive-streams/src/test/functional/com/mongodb/internal/connection/OidcAuthenticationAsyncProseTests.java new file mode 100644 index 00000000000..b18825e89a8 --- /dev/null +++ b/driver-reactive-streams/src/test/functional/com/mongodb/internal/connection/OidcAuthenticationAsyncProseTests.java @@ -0,0 +1,70 @@ +/* + * Copyright 2008-present MongoDB, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.mongodb.internal.connection; + +import com.mongodb.MongoClientSettings; +import com.mongodb.client.MongoClient; +import com.mongodb.reactivestreams.client.MongoClients; +import com.mongodb.reactivestreams.client.syncadapter.SyncMongoClient; +import org.junit.jupiter.api.Test; +import reactivestreams.helpers.SubscriberHelpers; + +import java.util.concurrent.TimeUnit; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static util.ThreadTestHelpers.executeAll; + +public class OidcAuthenticationAsyncProseTests extends OidcAuthenticationProseTests { + + @Override + protected MongoClient createMongoClient(final MongoClientSettings settings) { + return new SyncMongoClient(MongoClients.create(settings)); + } + + @Test + public void testNonblockingCallbacks() { + // not a prose spec test + delayNextFind(); + + int simulatedDelayMs = 100; + TestCallback requestCallback = createCallback().setExpired().setDelayMs(simulatedDelayMs); + TestCallback refreshCallback = createCallback().setDelayMs(simulatedDelayMs); + + MongoClientSettings clientSettings = createSettings(OIDC_URL, requestCallback, refreshCallback); + + try (com.mongodb.reactivestreams.client.MongoClient client = MongoClients.create(clientSettings)) { + executeAll(2, () -> { + SubscriberHelpers.OperationSubscriber subscriber = new SubscriberHelpers.OperationSubscriber<>(); + long t1 = System.nanoTime(); + client.getDatabase("test") + .getCollection("test") + .find() + .first() + .subscribe(subscriber); + long elapsedMs = TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - t1); + + assertTrue(elapsedMs < simulatedDelayMs); + subscriber.get(); + }); + + // ensure both callbacks have been tested + assertEquals(1, requestCallback.getInvocations()); + assertEquals(1, refreshCallback.getInvocations()); + } + } +} diff --git a/driver-sync/src/test/functional/com/mongodb/internal/connection/OidcAuthenticationProseTests.java b/driver-sync/src/test/functional/com/mongodb/internal/connection/OidcAuthenticationProseTests.java index 74e95d7e253..368e1342e1f 100644 --- a/driver-sync/src/test/functional/com/mongodb/internal/connection/OidcAuthenticationProseTests.java +++ b/driver-sync/src/test/functional/com/mongodb/internal/connection/OidcAuthenticationProseTests.java @@ -598,12 +598,16 @@ public void test6p1ReauthenticationSucceeds() { assertEquals(0, onRefresh.getInvocations()); assertEquals(Arrays.asList( + // speculative: "isMaster started", "isMaster succeeded", + // onRequest: "onRequest invoked", "read access token: test_user1", + // jwt from onRequest: "saslContinue started", "saslContinue succeeded", + // ensuing find: "find started", "find succeeded" ), listener.getEventStrings()); @@ -624,10 +628,12 @@ public void test6p1ReauthenticationSucceeds() { assertEquals(Arrays.asList( "find started", "find failed", + // find has triggered 391, and cleared the access token; fall back to refresh: "onRefresh invoked", "read access token: test_user1", "saslStart started", "saslStart succeeded", + // find retry succeeds: "find started", "find succeeded" ), listener.getEventStrings()); From c1830b7bdc2270001728bd40f3c2824279eae547 Mon Sep 17 00:00:00 2001 From: Maxim Katcharov Date: Wed, 21 Feb 2024 15:31:30 -0700 Subject: [PATCH 3/6] Remove non-machine workflow (#1259) * Remove non-machine workflow * Update prose tests to remove refresh token, principal-request * Conform to latest spec; remove lock around server auth * Rebase fix (async API) * Apply suggestions from code review Co-authored-by: Valentin Kovalenko * PR fixes --------- Co-authored-by: Valentin Kovalenko --- .../src/main/com/mongodb/MongoCredential.java | 135 +--- .../connection/InternalStreamConnection.java | 2 +- .../connection/OidcAuthenticator.java | 435 ++--------- .../auth/legacy/connection-string.json | 103 +-- .../auth/mongodb-oidc-no-retry.json | 428 +++++++++++ .../auth/reauthenticate_with_retry.json | 191 ----- .../auth/reauthenticate_without_retry.json | 191 ----- .../com/mongodb/AuthConnectionStringTest.java | 15 +- .../OidcAuthenticationAsyncProseTests.java | 6 +- .../com/mongodb/client/unified/Entities.java | 37 + .../mongodb/client/unified/ErrorMatcher.java | 18 +- .../unified/RunOnRequirementsMatcher.java | 9 + .../mongodb/client/unified/UnifiedTest.java | 4 +- .../OidcAuthenticationProseTests.java | 702 +++--------------- 14 files changed, 706 insertions(+), 1570 deletions(-) create mode 100644 driver-core/src/test/resources/unified-test-format/auth/mongodb-oidc-no-retry.json delete mode 100644 driver-core/src/test/resources/unified-test-format/auth/reauthenticate_with_retry.json delete mode 100644 driver-core/src/test/resources/unified-test-format/auth/reauthenticate_without_retry.json diff --git a/driver-core/src/main/com/mongodb/MongoCredential.java b/driver-core/src/main/com/mongodb/MongoCredential.java index 418863dc21c..4c10e1f640c 100644 --- a/driver-core/src/main/com/mongodb/MongoCredential.java +++ b/driver-core/src/main/com/mongodb/MongoCredential.java @@ -25,7 +25,6 @@ import java.util.Arrays; import java.util.Collections; import java.util.HashMap; -import java.util.List; import java.util.Map; import java.util.Objects; @@ -188,8 +187,7 @@ public final class MongoCredential { * The provider name. The value must be a string. *

* If this is provided, - * {@link MongoCredential#REQUEST_TOKEN_CALLBACK_KEY} and - * {@link MongoCredential#REFRESH_TOKEN_CALLBACK_KEY} + * {@link MongoCredential#OIDC_CALLBACK_KEY} * must not be provided. * * @see #createOidcCredential(String) @@ -208,45 +206,7 @@ public final class MongoCredential { * @see #createOidcCredential(String) * @since 4.10 */ - public static final String REQUEST_TOKEN_CALLBACK_KEY = "REQUEST_TOKEN_CALLBACK"; - - /** - * Mechanism key for invoked when the OIDC-based authenticator refreshes - * tokens from the identity provider. If this callback is not provided, - * then refresh operations will not be attempted.The type of the value - * must be {@link OidcRefreshCallback}. - *

- * If this is provided, {@link MongoCredential#PROVIDER_NAME_KEY} - * must not be provided. - * - * @see #createOidcCredential(String) - * @since 4.10 - */ - public static final String REFRESH_TOKEN_CALLBACK_KEY = "REFRESH_TOKEN_CALLBACK"; - - /** - * Mechanism key for a list of allowed hostnames or ip-addresses for MongoDB connections. Ports must be excluded. - * The hostnames may include a leading "*." wildcard, which allows for matching (potentially nested) subdomains. - * When MONGODB-OIDC authentication is attempted against a hostname that does not match any of list of allowed hosts - * the driver will raise an error. The type of the value must be {@code List}. - * - * @see MongoCredential#DEFAULT_ALLOWED_HOSTS - * @see #createOidcCredential(String) - * @since 4.10 - */ - public static final String ALLOWED_HOSTS_KEY = "ALLOWED_HOSTS"; - - /** - * The list of allowed hosts that will be used if no - * {@link MongoCredential#ALLOWED_HOSTS_KEY} value is supplied. - * The default allowed hosts are: - * {@code "*.mongodb.net", "*.mongodb-dev.net", "*.mongodbgov.net", "localhost", "127.0.0.1", "::1"} - * - * @see #createOidcCredential(String) - * @since 4.10 - */ - public static final List DEFAULT_ALLOWED_HOSTS = Collections.unmodifiableList(Arrays.asList( - "*.mongodb.net", "*.mongodb-dev.net", "*.mongodbgov.net", "localhost", "127.0.0.1", "::1")); + public static final String OIDC_CALLBACK_KEY = "OIDC_CALLBACK"; /** * Creates a MongoCredential instance with an unspecified mechanism. The client will negotiate the best mechanism based on the @@ -404,9 +364,7 @@ public static MongoCredential createAwsCredential(@Nullable final String userNam * @since 4.10 * @see #withMechanismProperty(String, Object) * @see #PROVIDER_NAME_KEY - * @see #REQUEST_TOKEN_CALLBACK_KEY - * @see #REFRESH_TOKEN_CALLBACK_KEY - * @see #ALLOWED_HOSTS_KEY + * @see #OIDC_CALLBACK_KEY * @mongodb.server.release 7.0 */ public static MongoCredential createOidcCredential(@Nullable final String userName) { @@ -639,26 +597,16 @@ public String toString() { */ @Evolving public interface OidcRequestContext { - /** - * @return The OIDC Identity Provider's configuration that can be used to acquire an Access Token. - */ - IdpInfo getIdpInfo(); /** * @return The timeout that this callback must complete within. */ Duration getTimeout(); - } - /** - * The context for the {@link OidcRefreshCallback#onRefresh(OidcRefreshContext) OIDC refresh callback}. - */ - @Evolving - public interface OidcRefreshContext extends OidcRequestContext { /** - * @return The OIDC Refresh token supplied by a prior callback invocation. + * @return The OIDC callback API version. Currently, version 1. */ - String getRefreshToken(); + int getVersion(); } /** @@ -673,72 +621,22 @@ public interface OidcRequestCallback { * @param context The context. * @return The response produced by an OIDC Identity Provider */ - IdpResponse onRequest(OidcRequestContext context); - } - - /** - * This callback is invoked when the OIDC-based authenticator refreshes - * tokens from the identity provider. If this callback is not provided, - * then refresh operations will not be attempted. - *

- * It does not have to be thread-safe, unless it is provided to multiple - * MongoClients. - */ - public interface OidcRefreshCallback { - /** - * @param context The context. - * @return The response produced by an OIDC Identity Provider - */ - IdpResponse onRefresh(OidcRefreshContext context); - } - - /** - * The OIDC Identity Provider's configuration that can be used to acquire an Access Token. - */ - @Evolving - public interface IdpInfo { - /** - * @return URL which describes the Authorization Server. This identifier is the - * iss of provided access tokens, and is viable for RFC8414 metadata - * discovery and RFC9207 identification. - */ - String getIssuer(); - - /** - * @return Unique client ID for this OIDC client. - */ - String getClientId(); - - /** - * @return Additional scopes to request from Identity Provider. Immutable. - */ - List getRequestScopes(); + RequestCallbackResult onRequest(OidcRequestContext context); } /** * The response produced by an OIDC Identity Provider. */ - public static final class IdpResponse { + public static final class RequestCallbackResult { private final String accessToken; - @Nullable - private final Integer accessTokenExpiresInSeconds; - - @Nullable - private final String refreshToken; - /** * @param accessToken The OIDC access token - * @param accessTokenExpiresInSeconds The expiration in seconds. If null, the access token is single-use. - * @param refreshToken The refresh token. If null, refresh will not be attempted. */ - public IdpResponse(final String accessToken, @Nullable final Integer accessTokenExpiresInSeconds, - @Nullable final String refreshToken) { + public RequestCallbackResult(final String accessToken) { notNull("accessToken", accessToken); this.accessToken = accessToken; - this.accessTokenExpiresInSeconds = accessTokenExpiresInSeconds; - this.refreshToken = refreshToken; } /** @@ -747,22 +645,5 @@ public IdpResponse(final String accessToken, @Nullable final Integer accessToken public String getAccessToken() { return accessToken; } - - /** - * @return The expiration time for the access token in seconds. - * If null, the access token is single-use. - */ - @Nullable - public Integer getAccessTokenExpiresInSeconds() { - return accessTokenExpiresInSeconds; - } - - /** - * @return The OIDC refresh token. If null, refresh will not be attempted. - */ - @Nullable - public String getRefreshToken() { - return refreshToken; - } } } diff --git a/driver-core/src/main/com/mongodb/internal/connection/InternalStreamConnection.java b/driver-core/src/main/com/mongodb/internal/connection/InternalStreamConnection.java index 09f2ec6a845..218835f083e 100644 --- a/driver-core/src/main/com/mongodb/internal/connection/InternalStreamConnection.java +++ b/driver-core/src/main/com/mongodb/internal/connection/InternalStreamConnection.java @@ -383,7 +383,7 @@ public void sendAndReceiveAsync(final CommandMessage message, final Decoder< message, decoder, sessionContext, requestContext, operationContext, c); beginAsync().thenSupply(c -> { sendAndReceiveAsyncInternal.getAsync(c); - }).onErrorIf(e -> reauthenticationIsTriggered(e), c -> { + }).onErrorIf(e -> reauthenticationIsTriggered(e), (t, c) -> { reauthenticateAndRetryAsync(sendAndReceiveAsyncInternal, c); }).finish(callback); } diff --git a/driver-core/src/main/com/mongodb/internal/connection/OidcAuthenticator.java b/driver-core/src/main/com/mongodb/internal/connection/OidcAuthenticator.java index 2d3387e9216..70f9682476c 100644 --- a/driver-core/src/main/com/mongodb/internal/connection/OidcAuthenticator.java +++ b/driver-core/src/main/com/mongodb/internal/connection/OidcAuthenticator.java @@ -21,8 +21,7 @@ import com.mongodb.MongoCommandException; import com.mongodb.MongoConfigurationException; import com.mongodb.MongoCredential; -import com.mongodb.MongoCredential.IdpInfo; -import com.mongodb.MongoCredential.IdpResponse; +import com.mongodb.MongoCredential.RequestCallbackResult; import com.mongodb.MongoException; import com.mongodb.MongoSecurityException; import com.mongodb.ServerAddress; @@ -30,14 +29,11 @@ import com.mongodb.connection.ClusterConnectionMode; import com.mongodb.connection.ConnectionDescription; import com.mongodb.internal.Locks; -import com.mongodb.internal.async.SingleResultCallback; import com.mongodb.internal.VisibleForTesting; -import com.mongodb.internal.time.Timeout; +import com.mongodb.internal.async.SingleResultCallback; import com.mongodb.lang.Nullable; import org.bson.BsonDocument; import org.bson.BsonString; -import org.bson.RawBsonDocument; -import org.jetbrains.annotations.NotNull; import javax.security.sasl.SaslClient; import java.io.IOException; @@ -46,24 +42,14 @@ import java.nio.file.Paths; import java.time.Duration; import java.util.Arrays; -import java.util.Collections; import java.util.List; import java.util.Map; -import java.util.Objects; -import java.util.concurrent.TimeUnit; -import java.util.function.Function; -import java.util.stream.Collectors; import static com.mongodb.AuthenticationMechanism.MONGODB_OIDC; -import static com.mongodb.MongoCredential.ALLOWED_HOSTS_KEY; -import static com.mongodb.MongoCredential.DEFAULT_ALLOWED_HOSTS; -import static com.mongodb.MongoCredential.OidcRefreshCallback; -import static com.mongodb.MongoCredential.OidcRefreshContext; import static com.mongodb.MongoCredential.OidcRequestCallback; import static com.mongodb.MongoCredential.OidcRequestContext; import static com.mongodb.MongoCredential.PROVIDER_NAME_KEY; -import static com.mongodb.MongoCredential.REFRESH_TOKEN_CALLBACK_KEY; -import static com.mongodb.MongoCredential.REQUEST_TOKEN_CALLBACK_KEY; +import static com.mongodb.MongoCredential.OIDC_CALLBACK_KEY; import static com.mongodb.assertions.Assertions.assertFalse; import static com.mongodb.assertions.Assertions.assertNotNull; import static com.mongodb.assertions.Assertions.assertTrue; @@ -80,10 +66,8 @@ public final class OidcAuthenticator extends SaslAuthenticator { private static final Duration CALLBACK_TIMEOUT = Duration.ofMinutes(5); - private static final String AWS_WEB_IDENTITY_TOKEN_FILE = "AWS_WEB_IDENTITY_TOKEN_FILE"; - - @Nullable - private ServerAddress serverAddress; + public static final String AWS_WEB_IDENTITY_TOKEN_FILE = "AWS_WEB_IDENTITY_TOKEN_FILE"; + private static final int CALLBACK_API_VERSION_NUMBER = 1; @Nullable private String connectionLastAccessToken; @@ -93,9 +77,6 @@ public final class OidcAuthenticator extends SaslAuthenticator { @Nullable private BsonDocument speculativeAuthenticateResponse; - @Nullable - private Function evaluateChallengeFunction; - public OidcAuthenticator(final MongoCredentialWithCache credential, final ClusterConnectionMode clusterConnectionMode, @Nullable final ServerApi serverApi) { super(credential, clusterConnectionMode, serverApi); @@ -113,7 +94,6 @@ public String getMechanismName() { @Override protected SaslClient createSaslClient(final ServerAddress serverAddress) { - this.serverAddress = serverAddress; MongoCredentialWithCache mongoCredentialWithCache = getMongoCredentialWithCache(); return new OidcSaslClient(mongoCredentialWithCache); } @@ -125,13 +105,9 @@ public BsonDocument createSpeculativeAuthenticateCommand(final InternalConnectio if (isAutomaticAuthentication()) { return wrapInSpeculative(prepareAwsTokenFromFileAsJwt()); } - String cachedAccessToken = getValidCachedAccessToken(); - MongoCredentialWithCache mongoCredentialWithCache = getMongoCredentialWithCache(); + String cachedAccessToken = getCachedAccessToken(); if (cachedAccessToken != null) { return wrapInSpeculative(prepareTokenAsJwt(cachedAccessToken)); - } else if (mongoCredentialWithCache.getOidcCacheEntry().getIdpInfo() == null) { - String userName = mongoCredentialWithCache.getCredential().getUserName(); - return wrapInSpeculative(prepareUsername(userName)); } else { // otherwise, skip speculative auth return null; @@ -141,7 +117,6 @@ public BsonDocument createSpeculativeAuthenticateCommand(final InternalConnectio } } - @NotNull private BsonDocument wrapInSpeculative(final byte[] outToken) { BsonDocument startDocument = createSaslStartCommandDocument(outToken) .append("db", new BsonString(getMongoCredential().getSource())); @@ -166,43 +141,31 @@ public void setSpeculativeAuthenticateResponse(@Nullable final BsonDocument resp speculativeAuthenticateResponse = response; } - @Nullable - private OidcRefreshCallback getRefreshCallback() { - return getMongoCredentialWithCache() - .getCredential() - .getMechanismProperty(REFRESH_TOKEN_CALLBACK_KEY, null); - } - @Nullable private OidcRequestCallback getRequestCallback() { return getMongoCredentialWithCache() .getCredential() - .getMechanismProperty(REQUEST_TOKEN_CALLBACK_KEY, null); + .getMechanismProperty(OIDC_CALLBACK_KEY, null); } @Override public void reauthenticate(final InternalConnection connection) { assertTrue(connection.opened()); - authLock(connection, connection.getDescription()); + authenticationLoop(connection, connection.getDescription()); } @Override public void reauthenticateAsync(final InternalConnection connection, final SingleResultCallback callback) { beginAsync().thenRun(c -> { assertTrue(connection.opened()); - authLockAsync(connection, connection.getDescription(), c); + authenticationLoopAsync(connection, connection.getDescription(), c); }).finish(callback); } @Override public void authenticate(final InternalConnection connection, final ConnectionDescription connectionDescription) { assertFalse(connection.opened()); - String accessToken = getValidCachedAccessToken(); - if (accessToken != null) { - authenticateOptimistically(connection, connectionDescription, accessToken); - } else { - authLock(connection, connectionDescription); - } + authenticationLoop(connection, connectionDescription); } @Override @@ -210,35 +173,7 @@ void authenticateAsync(final InternalConnection connection, final ConnectionDesc final SingleResultCallback callback) { beginAsync().thenRun(c -> { assertFalse(connection.opened()); - String accessToken = getValidCachedAccessToken(); - if (accessToken != null) { - authenticateOptimisticallyAsync(connection, connectionDescription, accessToken, c); - } else { - authLockAsync(connection, connectionDescription, c); - } - }).finish(callback); - } - - private void authenticateOptimistically(final InternalConnection connection, - final ConnectionDescription connectionDescription, final String accessToken) { - try { - authenticateUsingFunction(connection, connectionDescription, (challenge) -> prepareTokenAsJwt(accessToken)); - } catch (MongoSecurityException e) { - if (triggersRetry(e)) { - authLock(connection, connectionDescription); - } else { - throw e; - } - } - } - - private void authenticateOptimisticallyAsync(final InternalConnection connection, - final ConnectionDescription connectionDescription, final String accessToken, - final SingleResultCallback callback) { - beginAsync().thenRun(c -> { - authenticateUsingFunctionAsync(connection, connectionDescription, (challenge) -> prepareTokenAsJwt(accessToken), c); - }).onErrorIf(e -> triggersRetry(e), c -> { - authLockAsync(connection, connectionDescription, c); + authenticationLoopAsync(connection, connectionDescription, c); }).finish(callback); } @@ -254,60 +189,61 @@ private static boolean triggersRetry(@Nullable final Throwable t) { return false; } - private void authenticateUsingFunctionAsync(final InternalConnection connection, - final ConnectionDescription connectionDescription, final Function evaluateChallengeFunction, - final SingleResultCallback callback) { - this.evaluateChallengeFunction = evaluateChallengeFunction; - super.authenticateAsync(connection, connectionDescription, callback); - } - - private void authenticateUsingFunction( - final InternalConnection connection, - final ConnectionDescription connectionDescription, - final Function evaluateChallengeFunction) { - this.evaluateChallengeFunction = evaluateChallengeFunction; - super.authenticate(connection, connectionDescription); - } - - private void authLock(final InternalConnection connection, final ConnectionDescription description) { + private void authenticationLoop(final InternalConnection connection, final ConnectionDescription description) { fallbackState = FallbackState.INITIAL; - Locks.withLock(getMongoCredentialWithCache().getOidcLock(), () -> { - while (true) { - try { - authenticateUsingFunction(connection, description, (challenge) -> evaluate(challenge)); - break; - } catch (MongoSecurityException e) { - if (triggersRetry(e) && shouldRetryHandler()) { - continue; - } - throw e; + while (true) { + try { + super.authenticate(connection, description); + break; + } catch (MongoSecurityException e) { + if (triggersRetry(e) && shouldRetryHandler()) { + continue; } + throw e; } - }); + } } - private void authLockAsync(final InternalConnection connection, final ConnectionDescription description, + private void authenticationLoopAsync(final InternalConnection connection, final ConnectionDescription description, final SingleResultCallback callback) { fallbackState = FallbackState.INITIAL; - Locks.withLockAsync(getMongoCredentialWithCache().getOidcLock(), - beginAsync().thenRunRetryingWhile( - c -> authenticateUsingFunctionAsync(connection, description, (challenge) -> evaluate(challenge), c), - e -> triggersRetry(e) && shouldRetryHandler() - ), callback); + beginAsync().thenRunRetryingWhile( + c -> super.authenticateAsync(connection, description, c), + e -> triggersRetry(e) && shouldRetryHandler() + ).finish(callback); } private byte[] evaluate(final byte[] challenge) { if (isAutomaticAuthentication()) { return prepareAwsTokenFromFileAsJwt(); } + byte[][] jwt = new byte[1][]; + Locks.withLock(getMongoCredentialWithCache().getOidcLock(), () -> { + String cachedAccessToken = validatedCachedAccessToken(); - OidcRequestCallback requestCallback = assertNotNull(getRequestCallback()); + if (cachedAccessToken != null) { + jwt[0] = prepareTokenAsJwt(cachedAccessToken); + fallbackState = FallbackState.PHASE_1_CACHED_TOKEN; + } else { + // cache is empty + OidcRequestCallback requestCallback = assertNotNull(getRequestCallback()); + RequestCallbackResult result = requestCallback.onRequest(new OidcRequestContextImpl(CALLBACK_TIMEOUT)); + jwt[0] = populateCacheWithCallbackResultAndPrepareJwt(result); + fallbackState = FallbackState.PHASE_2_CALLBACK_TOKEN; + } + }); + return jwt[0]; + } + + /** + * Must be guarded by {@link MongoCredentialWithCache#getOidcLock()}. + */ + @Nullable + private String validatedCachedAccessToken() { MongoCredentialWithCache mongoCredentialWithCache = getMongoCredentialWithCache(); OidcCacheEntry cacheEntry = mongoCredentialWithCache.getOidcCacheEntry(); - String cachedAccessToken = getValidCachedAccessToken(); + String cachedAccessToken = getCachedAccessToken(); String invalidConnectionAccessToken = connectionLastAccessToken; - String cachedRefreshToken = cacheEntry.getRefreshToken(); - IdpInfo cachedIdpInfo = cacheEntry.getIdpInfo(); if (cachedAccessToken != null) { boolean cachedTokenIsInvalid = cachedAccessToken.equals(invalidConnectionAccessToken); @@ -316,45 +252,7 @@ private byte[] evaluate(final byte[] challenge) { cachedAccessToken = null; } } - OidcRefreshCallback refreshCallback = getRefreshCallback(); - if (cachedAccessToken != null) { - fallbackState = FallbackState.PHASE_1_CACHED_TOKEN; - return prepareTokenAsJwt(cachedAccessToken); - } else if (refreshCallback != null && cachedRefreshToken != null) { - assertNotNull(cachedIdpInfo); - // Invoke Refresh Callback using cached Refresh Token - validateAllowedHosts(getMongoCredential()); - fallbackState = FallbackState.PHASE_2_REFRESH_CALLBACK_TOKEN; - IdpResponse result = refreshCallback.onRefresh(new OidcRefreshContextImpl( - cachedIdpInfo, cachedRefreshToken, CALLBACK_TIMEOUT)); - return populateCacheWithCallbackResultAndPrepareJwt(cachedIdpInfo, result); - } else { - // cache is empty - - /* - A check for present idp info short-circuits phase-3a. - - If a challenge is present, it can only be a response to a - "principal-request", so the challenge must be the resulting - idp info. Such a request is made during speculative auth, - though the source is unimportant, as long as we detect and - use it here. - - Checking that the fallback state is not phase-3a ensures that - this does not loop infinitely in the case of a bug. - */ - boolean idpInfoNotPresent = challenge.length == 0; - if (fallbackState != FallbackState.PHASE_3A_PRINCIPAL && idpInfoNotPresent) { - fallbackState = FallbackState.PHASE_3A_PRINCIPAL; - return prepareUsername(mongoCredentialWithCache.getCredential().getUserName()); - } else { - IdpInfo idpInfo = toIdpInfo(challenge); - validateAllowedHosts(getMongoCredential()); - IdpResponse result = requestCallback.onRequest(new OidcRequestContextImpl(idpInfo, CALLBACK_TIMEOUT)); - fallbackState = FallbackState.PHASE_3B_REQUEST_CALLBACK_TOKEN; - return populateCacheWithCallbackResultAndPrepareJwt(idpInfo, result); - } - } + return cachedAccessToken; } private boolean isAutomaticAuthentication() { @@ -362,124 +260,53 @@ private boolean isAutomaticAuthentication() { } private boolean clientIsComplete() { - return fallbackState != FallbackState.PHASE_3A_PRINCIPAL; + return true; // all possibilities are 1-step } private boolean shouldRetryHandler() { - MongoCredentialWithCache mongoCredentialWithCache = getMongoCredentialWithCache(); - OidcCacheEntry cacheEntry = mongoCredentialWithCache.getOidcCacheEntry(); - if (fallbackState == FallbackState.PHASE_1_CACHED_TOKEN) { - // a cached access token failed - mongoCredentialWithCache.setOidcCacheEntry(cacheEntry - .clearAccessToken()); - } else if (fallbackState == FallbackState.PHASE_2_REFRESH_CALLBACK_TOKEN) { - // a refresh token failed - mongoCredentialWithCache.setOidcCacheEntry(cacheEntry - .clearAccessToken() - .clearRefreshToken()); - } else { - // a clean-restart failed - mongoCredentialWithCache.setOidcCacheEntry(cacheEntry - .clearAccessToken() - .clearRefreshToken()); - return false; - } - return true; + Locks.withLock(getMongoCredentialWithCache().getOidcLock(), () -> { + validatedCachedAccessToken(); + }); + return fallbackState == FallbackState.PHASE_1_CACHED_TOKEN; } @Nullable - private String getValidCachedAccessToken() { + private String getCachedAccessToken() { return getMongoCredentialWithCache() .getOidcCacheEntry() - .getValidCachedAccessToken(); + .getCachedAccessToken(); } static final class OidcCacheEntry { @Nullable private final String accessToken; - @Nullable - private final Timeout accessTokenExpiry; - @Nullable - private final String refreshToken; - @Nullable - private final IdpInfo idpInfo; @Override public String toString() { return "OidcCacheEntry{" - + "\n accessToken#hashCode='" + Objects.hashCode(accessToken) + '\'' - + ",\n accessTokenExpiry=" + accessTokenExpiry - + ",\n refreshToken='" + refreshToken + '\'' - + ",\n idpInfo=" + idpInfo + + "\n accessToken=[omitted]" + '}'; } - OidcCacheEntry(final IdpInfo idpInfo, final IdpResponse idpResponse) { - Integer accessTokenExpiresInSeconds = idpResponse.getAccessTokenExpiresInSeconds(); - if (accessTokenExpiresInSeconds != null) { - this.accessToken = idpResponse.getAccessToken(); - long accessTokenExpiryReservedSeconds = TimeUnit.MINUTES.toSeconds(5); - this.accessTokenExpiry = Timeout.startNow( - Math.max(0, accessTokenExpiresInSeconds - accessTokenExpiryReservedSeconds), - TimeUnit.SECONDS); - } else { - this.accessToken = null; - this.accessTokenExpiry = null; - } - String refreshToken = idpResponse.getRefreshToken(); - if (refreshToken != null) { - this.refreshToken = refreshToken; - this.idpInfo = idpInfo; - } else { - this.refreshToken = null; - this.idpInfo = null; - } + OidcCacheEntry(final RequestCallbackResult requestCallbackResult) { + this.accessToken = requestCallbackResult.getAccessToken(); } OidcCacheEntry() { - this(null, null, null, null); + this((String) null); } - private OidcCacheEntry(@Nullable final String accessToken, @Nullable final Timeout accessTokenExpiry, - @Nullable final String refreshToken, @Nullable final IdpInfo idpInfo) { + private OidcCacheEntry(@Nullable final String accessToken) { this.accessToken = accessToken; - this.accessTokenExpiry = accessTokenExpiry; - this.refreshToken = refreshToken; - this.idpInfo = idpInfo; } @Nullable - String getValidCachedAccessToken() { - if (accessToken == null || accessTokenExpiry == null || accessTokenExpiry.expired()) { - return null; - } + String getCachedAccessToken() { return accessToken; } - @Nullable - String getRefreshToken() { - return refreshToken; - } - - @Nullable - IdpInfo getIdpInfo() { - return idpInfo; - } - OidcCacheEntry clearAccessToken() { - return new OidcCacheEntry( - null, - null, - this.refreshToken, - this.idpInfo); - } - - OidcCacheEntry clearRefreshToken() { - return new OidcCacheEntry( - this.accessToken, - this.accessTokenExpiry, - null, - null); + return new OidcCacheEntry((String) null); } } @@ -491,7 +318,7 @@ private OidcSaslClient(final MongoCredentialWithCache mongoCredentialWithCache) @Override public byte[] evaluateChallenge(final byte[] challenge) { - return assertNotNull(evaluateChallengeFunction).apply(challenge); + return evaluate(challenge); } @Override @@ -516,65 +343,13 @@ private static String readAwsTokenFromFile() { } } - private static byte[] prepareUsername(@Nullable final String username) { - BsonDocument document = new BsonDocument(); - if (username != null) { - document = document.append("n", new BsonString(username)); - } - return toBson(document); - } - - private byte[] populateCacheWithCallbackResultAndPrepareJwt( - final IdpInfo serverInfo, - @Nullable final IdpResponse idpResponse) { - if (idpResponse == null) { + private byte[] populateCacheWithCallbackResultAndPrepareJwt(@Nullable final RequestCallbackResult requestCallbackResult) { + if (requestCallbackResult == null) { throw new MongoConfigurationException("Result of callback must not be null"); } - OidcCacheEntry newEntry = new OidcCacheEntry(serverInfo, idpResponse); + OidcCacheEntry newEntry = new OidcCacheEntry(requestCallbackResult); getMongoCredentialWithCache().setOidcCacheEntry(newEntry); - return prepareTokenAsJwt(idpResponse.getAccessToken()); - } - - private static IdpInfo toIdpInfo(final byte[] challenge) { - BsonDocument c = new RawBsonDocument(challenge); - String issuer = c.getString("issuer").getValue(); - String clientId = c.getString("clientId").getValue(); - return new IdpInfoImpl( - issuer, - clientId, - getStringArray(c, "requestScopes")); - } - - private void validateAllowedHosts(final MongoCredential credential) { - List allowedHosts = assertNotNull(credential.getMechanismProperty(ALLOWED_HOSTS_KEY, DEFAULT_ALLOWED_HOSTS)); - String host = assertNotNull(serverAddress).getHost(); - boolean permitted = allowedHosts.stream().anyMatch(allowedHost -> { - if (allowedHost.startsWith("*.")) { - String ending = allowedHost.substring(1); - return host.endsWith(ending); - } else if (allowedHost.contains("*")) { - throw new IllegalArgumentException( - "Allowed host " + allowedHost + " contains invalid wildcard"); - } else { - return host.equals(allowedHost); - } - }); - if (!permitted) { - throw new MongoSecurityException( - credential, "Host not permitted by " + ALLOWED_HOSTS_KEY + ": " + host); - } - } - - @Nullable - private static List getStringArray(final BsonDocument document, final String key) { - if (!document.isArray(key)) { - return null; - } - return document.getArray(key).stream() - // ignore non-string values from server, rather than error - .filter(v -> v.isString()) - .map(v -> v.asString().getValue()) - .collect(Collectors.toList()); + return prepareTokenAsJwt(requestCallbackResult.getAccessToken()); } private byte[] prepareTokenAsJwt(final String accessToken) { @@ -625,13 +400,12 @@ public static void validateCreateOidcCredential(@Nullable final char[] password) public static void validateBeforeUse(final MongoCredential credential) { String userName = credential.getUserName(); Object providerName = credential.getMechanismProperty(PROVIDER_NAME_KEY, null); - Object requestCallback = credential.getMechanismProperty(REQUEST_TOKEN_CALLBACK_KEY, null); - Object refreshCallback = credential.getMechanismProperty(REFRESH_TOKEN_CALLBACK_KEY, null); + Object requestCallback = credential.getMechanismProperty(OIDC_CALLBACK_KEY, null); if (providerName == null) { // callback if (requestCallback == null) { throw new IllegalArgumentException("Either " + PROVIDER_NAME_KEY + " or " - + REQUEST_TOKEN_CALLBACK_KEY + " must be specified"); + + OIDC_CALLBACK_KEY + " must be specified"); } } else { // automatic @@ -639,10 +413,7 @@ public static void validateBeforeUse(final MongoCredential credential) { throw new IllegalArgumentException("user name must not be specified when " + PROVIDER_NAME_KEY + " is specified"); } if (requestCallback != null) { - throw new IllegalArgumentException(REQUEST_TOKEN_CALLBACK_KEY + " must not be specified when " + PROVIDER_NAME_KEY + " is specified"); - } - if (refreshCallback != null) { - throw new IllegalArgumentException(REFRESH_TOKEN_CALLBACK_KEY + " must not be specified when " + PROVIDER_NAME_KEY + " is specified"); + throw new IllegalArgumentException(OIDC_CALLBACK_KEY + " must not be specified when " + PROVIDER_NAME_KEY + " is specified"); } } } @@ -651,81 +422,29 @@ public static void validateBeforeUse(final MongoCredential credential) { @VisibleForTesting(otherwise = VisibleForTesting.AccessModifier.PRIVATE) static class OidcRequestContextImpl implements OidcRequestContext { - private final IdpInfo idpInfo; private final Duration timeout; - OidcRequestContextImpl(final IdpInfo idpInfo, final Duration timeout) { - this.idpInfo = assertNotNull(idpInfo); + OidcRequestContextImpl(final Duration timeout) { this.timeout = assertNotNull(timeout); } - @Override - public IdpInfo getIdpInfo() { - return idpInfo; - } - @Override public Duration getTimeout() { return timeout; } - } - - @VisibleForTesting(otherwise = VisibleForTesting.AccessModifier.PRIVATE) - static final class OidcRefreshContextImpl extends OidcRequestContextImpl - implements OidcRefreshContext { - private final String refreshToken; - - OidcRefreshContextImpl(final IdpInfo idpInfo, final String refreshToken, - final Duration timeout) { - super(idpInfo, timeout); - this.refreshToken = assertNotNull(refreshToken); - } - - @Override - public String getRefreshToken() { - return refreshToken; - } - } - - @VisibleForTesting(otherwise = VisibleForTesting.AccessModifier.PRIVATE) - static final class IdpInfoImpl implements IdpInfo { - private final String issuer; - private final String clientId; - - private final List requestScopes; - - IdpInfoImpl(final String issuer, final String clientId, @Nullable final List requestScopes) { - this.issuer = assertNotNull(issuer); - this.clientId = assertNotNull(clientId); - this.requestScopes = requestScopes == null - ? Collections.emptyList() - : Collections.unmodifiableList(requestScopes); - } - - @Override - public String getIssuer() { - return issuer; - } - - @Override - public String getClientId() { - return clientId; - } @Override - public List getRequestScopes() { - return requestScopes; + public int getVersion() { + return CALLBACK_API_VERSION_NUMBER; } } /** - * Represents what was sent in the last request to the MongoDB server. + * What was sent in the last request by this connection to the server. */ private enum FallbackState { INITIAL, PHASE_1_CACHED_TOKEN, - PHASE_2_REFRESH_CALLBACK_TOKEN, - PHASE_3A_PRINCIPAL, - PHASE_3B_REQUEST_CALLBACK_TOKEN + PHASE_2_CALLBACK_TOKEN } } diff --git a/driver-core/src/test/resources/auth/legacy/connection-string.json b/driver-core/src/test/resources/auth/legacy/connection-string.json index 1d69685df10..f8521be9d19 100644 --- a/driver-core/src/test/resources/auth/legacy/connection-string.json +++ b/driver-core/src/test/resources/auth/legacy/connection-string.json @@ -446,68 +446,7 @@ } }, { - "description": "should recognise the mechanism and request callback (MONGODB-OIDC)", - "uri": "mongodb://localhost/?authMechanism=MONGODB-OIDC", - "callback": ["oidcRequest"], - "valid": true, - "credential": { - "username": null, - "password": null, - "source": "$external", - "mechanism": "MONGODB-OIDC", - "mechanism_properties": { - "REQUEST_TOKEN_CALLBACK": true - } - } - }, - { - "description": "should recognise the mechanism when auth source is explicitly specified and with request callback (MONGODB-OIDC)", - "uri": "mongodb://localhost/?authMechanism=MONGODB-OIDC&authSource=$external", - "callback": ["oidcRequest"], - "valid": true, - "credential": { - "username": null, - "password": null, - "source": "$external", - "mechanism": "MONGODB-OIDC", - "mechanism_properties": { - "REQUEST_TOKEN_CALLBACK": true - } - } - }, - { - "description": "should recognise the mechanism with request and refresh callback (MONGODB-OIDC)", - "uri": "mongodb://localhost/?authMechanism=MONGODB-OIDC", - "callback": ["oidcRequest", "oidcRefresh"], - "valid": true, - "credential": { - "username": null, - "password": null, - "source": "$external", - "mechanism": "MONGODB-OIDC", - "mechanism_properties": { - "REQUEST_TOKEN_CALLBACK": true, - "REFRESH_TOKEN_CALLBACK": true - } - } - }, - { - "description": "should recognise the mechanism and username with request callback (MONGODB-OIDC)", - "uri": "mongodb://principalName@localhost/?authMechanism=MONGODB-OIDC", - "callback": ["oidcRequest"], - "valid": true, - "credential": { - "username": "principalName", - "password": null, - "source": "$external", - "mechanism": "MONGODB-OIDC", - "mechanism_properties": { - "REQUEST_TOKEN_CALLBACK": true - } - } - }, - { - "description": "should recognise the mechanism with aws device (MONGODB-OIDC)", + "description": "should recognise the mechanism with aws provider (MONGODB-OIDC)", "uri": "mongodb://localhost/?authMechanism=MONGODB-OIDC&authMechanismProperties=PROVIDER_NAME:aws", "valid": true, "credential": { @@ -521,7 +460,7 @@ } }, { - "description": "should recognise the mechanism when auth source is explicitly specified and with aws device (MONGODB-OIDC)", + "description": "should recognise the mechanism when auth source is explicitly specified and with provider (MONGODB-OIDC)", "uri": "mongodb://localhost/?authMechanism=MONGODB-OIDC&authSource=$external&authMechanismProperties=PROVIDER_NAME:aws", "valid": true, "credential": { @@ -535,51 +474,29 @@ } }, { - "description": "should throw an exception if username and password are specified (MONGODB-OIDC)", - "uri": "mongodb://user:pass@localhost/?authMechanism=MONGODB-OIDC", - "callback": ["oidcRequest"], + "description": "should throw an exception if supplied a password (MONGODB-OIDC)", + "uri": "mongodb://user:pass@localhost/?authMechanism=MONGODB-OIDC&authMechanismProperties=PROVIDER_NAME:aws", "valid": false, "credential": null }, { - "description": "should throw an exception if username and deviceName are specified (MONGODB-OIDC)", - "uri": "mongodb://principalName@localhost/?authMechanism=MONGODB-OIDC&PROVIDER_NAME:gcp", + "description": "should throw an exception if username is specified for aws (MONGODB-OIDC)", + "uri": "mongodb://principalName@localhost/?authMechanism=MONGODB-OIDC&PROVIDER_NAME:aws", "valid": false, "credential": null }, { - "description": "should throw an exception if specified deviceName is not supported (MONGODB-OIDC)", - "uri": "mongodb://localhost/?authMechanism=MONGODB-OIDC&authMechanismProperties=PROVIDER_NAME:unexisted", + "description": "should throw an exception if specified provider is not supported (MONGODB-OIDC)", + "uri": "mongodb://localhost/?authMechanism=MONGODB-OIDC&authMechanismProperties=PROVIDER_NAME:invalid", "valid": false, "credential": null }, { - "description": "should throw an exception if neither deviceName nor callbacks specified (MONGODB-OIDC)", + "description": "should throw an exception if neither provider nor callbacks specified (MONGODB-OIDC)", "uri": "mongodb://localhost/?authMechanism=MONGODB-OIDC", "valid": false, "credential": null }, - { - "description": "should throw an exception when only refresh callback is specified (MONGODB-OIDC)", - "uri": "mongodb://localhost/?authMechanism=MONGODB-OIDC", - "callback": ["oidcRefresh"], - "valid": false, - "credential": null - }, - { - "description": "should throw an exception if provider name and request callback are specified", - "uri": "mongodb://localhost/?authMechanism=MONGODB-OIDC&authMechanismProperties=PROVIDER_NAME:aws", - "callback": ["oidcRequest"], - "valid": false, - "credential": null - }, - { - "description": "should throw an exception if provider name and refresh callback are specified", - "uri": "mongodb://localhost/?authMechanism=MONGODB-OIDC&authMechanismProperties=PROVIDER_NAME:aws", - "callback": ["oidcRefresh"], - "valid": false, - "credential": null - }, { "description": "should throw an exception when unsupported auth property is specified (MONGODB-OIDC)", "uri": "mongodb://localhost/?authMechanism=MONGODB-OIDC&authMechanismProperties=UnsupportedProperty:unexisted", @@ -587,4 +504,4 @@ "credential": null } ] -} \ No newline at end of file +} diff --git a/driver-core/src/test/resources/unified-test-format/auth/mongodb-oidc-no-retry.json b/driver-core/src/test/resources/unified-test-format/auth/mongodb-oidc-no-retry.json new file mode 100644 index 00000000000..7287c2486f0 --- /dev/null +++ b/driver-core/src/test/resources/unified-test-format/auth/mongodb-oidc-no-retry.json @@ -0,0 +1,428 @@ +{ + "description": "MONGODB-OIDC authentication with retry disabled", + "schemaVersion": "1.19", + "runOnRequirements": [ + { + "minServerVersion": "7.0", + "auth": true, + "authMechanism": "MONGODB-OIDC" + } + ], + "createEntities": [ + { + "client": { + "id": "failPointClient", + "useMultipleMongoses": false + } + }, + { + "client": { + "id": "client0", + "uriOptions": { + "authMechanism": "MONGODB-OIDC", + "authMechanismProperties": { + "$$placeholder": 1 + }, + "retryReads": false, + "retryWrites": false + }, + "observeEvents": [ + "commandStartedEvent", + "commandSucceededEvent", + "commandFailedEvent" + ] + } + }, + { + "database": { + "id": "database0", + "client": "client0", + "databaseName": "test" + } + }, + { + "collection": { + "id": "collection0", + "database": "database0", + "collectionName": "collName" + } + } + ], + "initialData": [ + { + "collectionName": "collName", + "databaseName": "test", + "documents": [ + + ] + } + ], + "tests": [ + { + "description": "A read operation should succeed", + "operations": [ + { + "name": "find", + "object": "collection0", + "arguments": { + "filter": { + } + }, + "expectResult": [ + + ] + } + ], + "expectEvents": [ + { + "client": "client0", + "events": [ + { + "commandStartedEvent": { + "command": { + "find": "collName", + "filter": { + } + } + } + }, + { + "commandSucceededEvent": { + "commandName": "find" + } + } + ] + } + ] + }, + { + "description": "A write operation should succeed", + "operations": [ + { + "name": "insertOne", + "object": "collection0", + "arguments": { + "document": { + "_id": 1, + "x": 1 + } + } + } + ], + "expectEvents": [ + { + "client": "client0", + "events": [ + { + "commandStartedEvent": { + "command": { + "insert": "collName", + "documents": [ + { + "_id": 1, + "x": 1 + } + ] + } + } + }, + { + "commandSucceededEvent": { + "commandName": "insert" + } + } + ] + } + ] + }, + { + "description": "Read commands should reauthenticate and retry when a ReauthenticationRequired error happens", + "operations": [ + { + "name": "failPoint", + "object": "testRunner", + "arguments": { + "client": "failPointClient", + "failPoint": { + "configureFailPoint": "failCommand", + "mode": { + "times": 1 + }, + "data": { + "failCommands": [ + "find" + ], + "errorCode": 391 + } + } + } + }, + { + "name": "find", + "object": "collection0", + "arguments": { + "filter": { + } + }, + "expectResult": [ + + ] + } + ], + "expectEvents": [ + { + "client": "client0", + "events": [ + { + "commandStartedEvent": { + "command": { + "find": "collName", + "filter": { + } + } + } + }, + { + "commandFailedEvent": { + "commandName": "find" + } + }, + { + "commandStartedEvent": { + "command": { + "find": "collName", + "filter": { + } + } + } + }, + { + "commandSucceededEvent": { + "commandName": "find" + } + } + ] + } + ] + }, + { + "description": "Write commands should reauthenticate and retry when a ReauthenticationRequired error happens", + "operations": [ + { + "name": "failPoint", + "object": "testRunner", + "arguments": { + "client": "failPointClient", + "failPoint": { + "configureFailPoint": "failCommand", + "mode": { + "times": 1 + }, + "data": { + "failCommands": [ + "insert" + ], + "errorCode": 391 + } + } + } + }, + { + "name": "insertOne", + "object": "collection0", + "arguments": { + "document": { + "_id": 1, + "x": 1 + } + } + } + ], + "expectEvents": [ + { + "client": "client0", + "events": [ + { + "commandStartedEvent": { + "command": { + "insert": "collName", + "documents": [ + { + "_id": 1, + "x": 1 + } + ] + } + } + }, + { + "commandFailedEvent": { + "commandName": "insert" + } + }, + { + "commandStartedEvent": { + "command": { + "insert": "collName", + "documents": [ + { + "_id": 1, + "x": 1 + } + ] + } + } + }, + { + "commandSucceededEvent": { + "commandName": "insert" + } + } + ] + } + ] + }, + { + "description": "Handshake with cached token should use speculative authentication", + "operations": [ + { + "name": "failPoint", + "object": "testRunner", + "arguments": { + "client": "failPointClient", + "failPoint": { + "configureFailPoint": "failCommand", + "mode": { + "times": 1 + }, + "data": { + "failCommands": [ + "insert" + ], + "closeConnection": true + } + } + } + }, + { + "name": "insertOne", + "object": "collection0", + "arguments": { + "document": { + "_id": 1, + "x": 1 + } + }, + "expectError": { + "isClientError": true + } + }, + { + "name": "failPoint", + "object": "testRunner", + "arguments": { + "client": "failPointClient", + "failPoint": { + "configureFailPoint": "failCommand", + "mode": "alwaysOn", + "data": { + "failCommands": [ + "saslStart" + ], + "errorCode": 20 + } + } + } + }, + { + "name": "insertOne", + "object": "collection0", + "arguments": { + "document": { + "_id": 1, + "x": 1 + } + } + } + ], + "expectEvents": [ + { + "client": "client0", + "events": [ + { + "commandStartedEvent": { + "command": { + "insert": "collName", + "documents": [ + { + "_id": 1, + "x": 1 + } + ] + } + } + }, + { + "commandFailedEvent": { + "commandName": "insert" + } + }, + { + "commandStartedEvent": { + "command": { + "insert": "collName", + "documents": [ + { + "_id": 1, + "x": 1 + } + ] + } + } + }, + { + "commandSucceededEvent": { + "commandName": "insert" + } + } + ] + } + ] + }, + { + "description": "Handshake without cached token should not use speculative authentication", + "operations": [ + { + "name": "failPoint", + "object": "testRunner", + "arguments": { + "client": "failPointClient", + "failPoint": { + "configureFailPoint": "failCommand", + "mode": "alwaysOn", + "data": { + "failCommands": [ + "saslStart" + ], + "errorCode": 20 + } + } + } + }, + { + "name": "insertOne", + "object": "collection0", + "arguments": { + "document": { + "_id": 1, + "x": 1 + } + }, + "expectError": { + "errorCode": 20 + } + } + ] + } + ] +} \ No newline at end of file diff --git a/driver-core/src/test/resources/unified-test-format/auth/reauthenticate_with_retry.json b/driver-core/src/test/resources/unified-test-format/auth/reauthenticate_with_retry.json deleted file mode 100644 index c99ebc6ece2..00000000000 --- a/driver-core/src/test/resources/unified-test-format/auth/reauthenticate_with_retry.json +++ /dev/null @@ -1,191 +0,0 @@ -{ - "description": "reauthenticate_with_retry", - "schemaVersion": "1.12", - "runOnRequirements": [ - { - "minServerVersion": "6.3", - "auth": true - } - ], - "createEntities": [ - { - "client": { - "id": "client0", - "uriOptions": { - "retryReads": true, - "retryWrites": true - }, - "observeEvents": [ - "commandStartedEvent", - "commandSucceededEvent", - "commandFailedEvent" - ] - } - }, - { - "database": { - "id": "database0", - "client": "client0", - "databaseName": "db" - } - }, - { - "collection": { - "id": "collection0", - "database": "database0", - "collectionName": "collName" - } - } - ], - "initialData": [ - { - "collectionName": "collName", - "databaseName": "db", - "documents": [] - } - ], - "tests": [ - { - "description": "Read command should reauthenticate when receive ReauthenticationRequired error code and retryReads=true", - "operations": [ - { - "name": "failPoint", - "object": "testRunner", - "arguments": { - "client": "client0", - "failPoint": { - "configureFailPoint": "failCommand", - "mode": { - "times": 1 - }, - "data": { - "failCommands": [ - "find" - ], - "errorCode": 391 - } - } - } - }, - { - "name": "find", - "arguments": { - "filter": {} - }, - "object": "collection0", - "expectResult": [] - } - ], - "expectEvents": [ - { - "client": "client0", - "events": [ - { - "commandStartedEvent": { - "command": { - "find": "collName", - "filter": {} - } - } - }, - { - "commandFailedEvent": { - "commandName": "find" - } - }, - { - "commandStartedEvent": { - "command": { - "find": "collName", - "filter": {} - } - } - }, - { - "commandSucceededEvent": { - "commandName": "find" - } - } - ] - } - ] - }, - { - "description": "Write command should reauthenticate when receive ReauthenticationRequired error code and retryWrites=true", - "operations": [ - { - "name": "failPoint", - "object": "testRunner", - "arguments": { - "client": "client0", - "failPoint": { - "configureFailPoint": "failCommand", - "mode": { - "times": 1 - }, - "data": { - "failCommands": [ - "insert" - ], - "errorCode": 391 - } - } - } - }, - { - "name": "insertOne", - "object": "collection0", - "arguments": { - "document": { - "_id": 1, - "x": 1 - } - } - } - ], - "expectEvents": [ - { - "client": "client0", - "events": [ - { - "commandStartedEvent": { - "command": { - "insert": "collName", - "documents": [ - { - "_id": 1, - "x": 1 - } - ] - } - } - }, - { - "commandFailedEvent": { - "commandName": "insert" - } - }, - { - "commandStartedEvent": { - "command": { - "insert": "collName", - "documents": [ - { - "_id": 1, - "x": 1 - } - ] - } - } - }, - { - "commandSucceededEvent": { - "commandName": "insert" - } - } - ] - } - ] - } - ] -} \ No newline at end of file diff --git a/driver-core/src/test/resources/unified-test-format/auth/reauthenticate_without_retry.json b/driver-core/src/test/resources/unified-test-format/auth/reauthenticate_without_retry.json deleted file mode 100644 index 799057bf74f..00000000000 --- a/driver-core/src/test/resources/unified-test-format/auth/reauthenticate_without_retry.json +++ /dev/null @@ -1,191 +0,0 @@ -{ - "description": "reauthenticate_without_retry", - "schemaVersion": "1.12", - "runOnRequirements": [ - { - "minServerVersion": "6.3", - "auth": true - } - ], - "createEntities": [ - { - "client": { - "id": "client0", - "uriOptions": { - "retryReads": false, - "retryWrites": false - }, - "observeEvents": [ - "commandStartedEvent", - "commandSucceededEvent", - "commandFailedEvent" - ] - } - }, - { - "database": { - "id": "database0", - "client": "client0", - "databaseName": "db" - } - }, - { - "collection": { - "id": "collection0", - "database": "database0", - "collectionName": "collName" - } - } - ], - "initialData": [ - { - "collectionName": "collName", - "databaseName": "db", - "documents": [] - } - ], - "tests": [ - { - "description": "Read command should reauthenticate when receive ReauthenticationRequired error code and retryReads=false", - "operations": [ - { - "name": "failPoint", - "object": "testRunner", - "arguments": { - "client": "client0", - "failPoint": { - "configureFailPoint": "failCommand", - "mode": { - "times": 1 - }, - "data": { - "failCommands": [ - "find" - ], - "errorCode": 391 - } - } - } - }, - { - "name": "find", - "arguments": { - "filter": {} - }, - "object": "collection0", - "expectResult": [] - } - ], - "expectEvents": [ - { - "client": "client0", - "events": [ - { - "commandStartedEvent": { - "command": { - "find": "collName", - "filter": {} - } - } - }, - { - "commandFailedEvent": { - "commandName": "find" - } - }, - { - "commandStartedEvent": { - "command": { - "find": "collName", - "filter": {} - } - } - }, - { - "commandSucceededEvent": { - "commandName": "find" - } - } - ] - } - ] - }, - { - "description": "Write command should reauthenticate when receive ReauthenticationRequired error code and retryWrites=false", - "operations": [ - { - "name": "failPoint", - "object": "testRunner", - "arguments": { - "client": "client0", - "failPoint": { - "configureFailPoint": "failCommand", - "mode": { - "times": 1 - }, - "data": { - "failCommands": [ - "insert" - ], - "errorCode": 391 - } - } - } - }, - { - "name": "insertOne", - "object": "collection0", - "arguments": { - "document": { - "_id": 1, - "x": 1 - } - } - } - ], - "expectEvents": [ - { - "client": "client0", - "events": [ - { - "commandStartedEvent": { - "command": { - "insert": "collName", - "documents": [ - { - "_id": 1, - "x": 1 - } - ] - } - } - }, - { - "commandFailedEvent": { - "commandName": "insert" - } - }, - { - "commandStartedEvent": { - "command": { - "insert": "collName", - "documents": [ - { - "_id": 1, - "x": 1 - } - ] - } - } - }, - { - "commandSucceededEvent": { - "commandName": "insert" - } - } - ] - } - ] - } - ] -} \ No newline at end of file diff --git a/driver-core/src/test/unit/com/mongodb/AuthConnectionStringTest.java b/driver-core/src/test/unit/com/mongodb/AuthConnectionStringTest.java index 7f4acab857d..4da83dc7d4f 100644 --- a/driver-core/src/test/unit/com/mongodb/AuthConnectionStringTest.java +++ b/driver-core/src/test/unit/com/mongodb/AuthConnectionStringTest.java @@ -37,8 +37,7 @@ import java.util.List; import static com.mongodb.AuthenticationMechanism.MONGODB_OIDC; -import static com.mongodb.MongoCredential.REFRESH_TOKEN_CALLBACK_KEY; -import static com.mongodb.MongoCredential.REQUEST_TOKEN_CALLBACK_KEY; +import static com.mongodb.MongoCredential.OIDC_CALLBACK_KEY; // See https://github.com/mongodb/specifications/tree/master/source/auth/legacy/tests @RunWith(Parameterized.class) @@ -119,12 +118,8 @@ private MongoCredential getMongoCredential() { String string = ((BsonString) v).getValue(); if ("oidcRequest".equals(string)) { credential = credential.withMechanismProperty( - REQUEST_TOKEN_CALLBACK_KEY, + OIDC_CALLBACK_KEY, (MongoCredential.OidcRequestCallback) (context) -> null); - } else if ("oidcRefresh".equals(string)) { - credential = credential.withMechanismProperty( - REFRESH_TOKEN_CALLBACK_KEY, - (MongoCredential.OidcRefreshCallback) (context) -> null); } else { fail("Unsupported callback: " + string); } @@ -180,14 +175,10 @@ private void assertMechanismProperties(final MongoCredential credential) { } } else if ((document.get(key).isBoolean())) { boolean expectedValue = document.getBoolean(key).getValue(); - if (REQUEST_TOKEN_CALLBACK_KEY.equals(key)) { + if (OIDC_CALLBACK_KEY.equals(key)) { assertTrue(actualMechanismProperty instanceof MongoCredential.OidcRequestCallback); return; } - if (REFRESH_TOKEN_CALLBACK_KEY.equals(key)) { - assertTrue(actualMechanismProperty instanceof MongoCredential.OidcRefreshCallback); - return; - } assertNotNull(actualMechanismProperty); assertEquals(expectedValue, actualMechanismProperty); } else { diff --git a/driver-reactive-streams/src/test/functional/com/mongodb/internal/connection/OidcAuthenticationAsyncProseTests.java b/driver-reactive-streams/src/test/functional/com/mongodb/internal/connection/OidcAuthenticationAsyncProseTests.java index b18825e89a8..276dc9b68a9 100644 --- a/driver-reactive-streams/src/test/functional/com/mongodb/internal/connection/OidcAuthenticationAsyncProseTests.java +++ b/driver-reactive-streams/src/test/functional/com/mongodb/internal/connection/OidcAuthenticationAsyncProseTests.java @@ -42,10 +42,9 @@ public void testNonblockingCallbacks() { delayNextFind(); int simulatedDelayMs = 100; - TestCallback requestCallback = createCallback().setExpired().setDelayMs(simulatedDelayMs); - TestCallback refreshCallback = createCallback().setDelayMs(simulatedDelayMs); + TestCallback requestCallback = createCallback().setDelayMs(simulatedDelayMs); - MongoClientSettings clientSettings = createSettings(OIDC_URL, requestCallback, refreshCallback); + MongoClientSettings clientSettings = createSettings(getOidcUri(), requestCallback); try (com.mongodb.reactivestreams.client.MongoClient client = MongoClients.create(clientSettings)) { executeAll(2, () -> { @@ -64,7 +63,6 @@ public void testNonblockingCallbacks() { // ensure both callbacks have been tested assertEquals(1, requestCallback.getInvocations()); - assertEquals(1, refreshCallback.getInvocations()); } } } diff --git a/driver-sync/src/test/functional/com/mongodb/client/unified/Entities.java b/driver-sync/src/test/functional/com/mongodb/client/unified/Entities.java index 773addf8767..5f42066aada 100644 --- a/driver-sync/src/test/functional/com/mongodb/client/unified/Entities.java +++ b/driver-sync/src/test/functional/com/mongodb/client/unified/Entities.java @@ -16,14 +16,17 @@ package com.mongodb.client.unified; +import com.mongodb.AuthenticationMechanism; import com.mongodb.ClientEncryptionSettings; import com.mongodb.ClientSessionOptions; import com.mongodb.MongoClientSettings; +import com.mongodb.MongoCredential; import com.mongodb.ReadConcern; import com.mongodb.ReadConcernLevel; import com.mongodb.ReadPreference; import com.mongodb.ServerApi; import com.mongodb.ServerApiVersion; +import com.mongodb.internal.connection.OidcAuthenticator; import com.mongodb.event.TestServerMonitorListener; import com.mongodb.internal.connection.ServerMonitoringModeUtil; import com.mongodb.internal.connection.TestClusterListener; @@ -70,9 +73,15 @@ import org.bson.BsonDouble; import org.bson.BsonInt32; import org.bson.BsonInt64; +import org.bson.BsonNumber; import org.bson.BsonString; import org.bson.BsonValue; +import java.io.IOException; +import java.nio.charset.StandardCharsets; +import java.nio.file.Files; +import java.nio.file.Path; +import java.nio.file.Paths; import java.util.ArrayList; import java.util.HashMap; import java.util.HashSet; @@ -98,6 +107,7 @@ import static com.mongodb.client.unified.UnifiedCrudHelper.asReadPreference; import static com.mongodb.client.unified.UnifiedCrudHelper.asWriteConcern; import static com.mongodb.internal.connection.AbstractConnectionPoolTest.waitForPoolAsyncWorkManagerStart; +import static java.lang.System.getenv; import static java.util.Arrays.asList; import static java.util.Collections.synchronizedList; import static org.junit.Assume.assumeTrue; @@ -518,6 +528,33 @@ private void initClient(final BsonDocument entity, final String id, clientSettingsBuilder.applyToServerSettings(builder -> builder.serverMonitoringMode( ServerMonitoringModeUtil.fromString(value.asString().getValue()))); break; + case "authMechanism": + if (value.equals(new BsonString(AuthenticationMechanism.MONGODB_OIDC.getMechanismName()))) { + clientSettingsBuilder.credential(MongoCredential.createOidcCredential(null)); + break; + } + throw new UnsupportedOperationException("Unsupported authMechanism: " + value); + case "authMechanismProperties": + MongoCredential credential = clientSettingsBuilder.build().getCredential(); + boolean isOidc = credential != null + && credential.getAuthenticationMechanism() == AuthenticationMechanism.MONGODB_OIDC; + boolean hasPlaceholder = value.equals(new BsonDocument("$$placeholder", new BsonInt32(1))); + if (isOidc && hasPlaceholder) { + clientSettingsBuilder.credential(credential.withMechanismProperty( + MongoCredential.OIDC_CALLBACK_KEY, + (MongoCredential.OidcRequestCallback) context -> { + Path path = Paths.get(getenv(OidcAuthenticator.AWS_WEB_IDENTITY_TOKEN_FILE)); + String accessToken; + try { + accessToken = new String(Files.readAllBytes(path), StandardCharsets.UTF_8); + } catch (IOException e) { + throw new RuntimeException(e); + } + return new MongoCredential.RequestCallbackResult(accessToken); + })); + break; + } + throw new UnsupportedOperationException("Failure to apply authMechanismProperties: " + value); default: throw new UnsupportedOperationException("Unsupported uri option: " + key); } diff --git a/driver-sync/src/test/functional/com/mongodb/client/unified/ErrorMatcher.java b/driver-sync/src/test/functional/com/mongodb/client/unified/ErrorMatcher.java index e232a4c9688..7c0d340a9ad 100644 --- a/driver-sync/src/test/functional/com/mongodb/client/unified/ErrorMatcher.java +++ b/driver-sync/src/test/functional/com/mongodb/client/unified/ErrorMatcher.java @@ -20,6 +20,7 @@ import com.mongodb.MongoClientException; import com.mongodb.MongoCommandException; import com.mongodb.MongoException; +import com.mongodb.MongoSecurityException; import com.mongodb.MongoExecutionTimeoutException; import com.mongodb.MongoServerException; import com.mongodb.MongoSocketException; @@ -76,12 +77,17 @@ void assertErrorsMatch(final BsonDocument expectedError, final Exception e) { valueMatcher.assertValuesMatch(expectedError.getDocument("errorResponse"), ((MongoCommandException) e).getResponse()); } if (expectedError.containsKey("errorCode")) { - assertTrue(context.getMessage("Exception must be of type MongoCommandException or MongoQueryException when checking" - + " for error codes"), - e instanceof MongoCommandException || e instanceof MongoWriteException); - int errorCode = (e instanceof MongoCommandException) - ? ((MongoCommandException) e).getErrorCode() - : ((MongoWriteException) e).getCode(); + Exception errorCodeException = e; + if (e instanceof MongoSecurityException && e.getCause() instanceof MongoCommandException) { + errorCodeException = (Exception) e.getCause(); + } + assertTrue(context.getMessage("Exception must be of type MongoCommandException or MongoWriteException when checking" + + " for error codes, but was " + e.getClass().getSimpleName()), + errorCodeException instanceof MongoCommandException + || errorCodeException instanceof MongoWriteException); + int errorCode = (errorCodeException instanceof MongoCommandException) + ? ((MongoCommandException) errorCodeException).getErrorCode() + : ((MongoWriteException) errorCodeException).getCode(); assertEquals(context.getMessage("Error codes must match"), expectedError.getNumber("errorCode").intValue(), errorCode); diff --git a/driver-sync/src/test/functional/com/mongodb/client/unified/RunOnRequirementsMatcher.java b/driver-sync/src/test/functional/com/mongodb/client/unified/RunOnRequirementsMatcher.java index bf6c0dcda01..aa7a3f80a53 100644 --- a/driver-sync/src/test/functional/com/mongodb/client/unified/RunOnRequirementsMatcher.java +++ b/driver-sync/src/test/functional/com/mongodb/client/unified/RunOnRequirementsMatcher.java @@ -74,6 +74,15 @@ public static boolean runOnRequirementsMet(final BsonArray runOnRequirements, fi break requirementLoop; } break; + case "authMechanism": + boolean containsMechanism = getServerParameters() + .getArray("authenticationMechanisms") + .contains(curRequirement.getValue()); + if (!containsMechanism) { + requirementMet = false; + break requirementLoop; + } + break; case "serverParameters": BsonDocument serverParameters = getServerParameters(); for (Map.Entry curParameter: curRequirement.getValue().asDocument().entrySet()) { diff --git a/driver-sync/src/test/functional/com/mongodb/client/unified/UnifiedTest.java b/driver-sync/src/test/functional/com/mongodb/client/unified/UnifiedTest.java index c1741bd5f33..62eac081d4e 100644 --- a/driver-sync/src/test/functional/com/mongodb/client/unified/UnifiedTest.java +++ b/driver-sync/src/test/functional/com/mongodb/client/unified/UnifiedTest.java @@ -210,7 +210,9 @@ public void setUp() { || schemaVersion.equals("1.14") || schemaVersion.equals("1.15") || schemaVersion.equals("1.16") - || schemaVersion.equals("1.17")); + || schemaVersion.equals("1.17") + || schemaVersion.equals("1.18") + || schemaVersion.equals("1.19")); if (runOnRequirements != null) { assumeTrue("Run-on requirements not met", runOnRequirementsMet(runOnRequirements, getMongoClientSettings(), getServerVersion())); diff --git a/driver-sync/src/test/functional/com/mongodb/internal/connection/OidcAuthenticationProseTests.java b/driver-sync/src/test/functional/com/mongodb/internal/connection/OidcAuthenticationProseTests.java index 368e1342e1f..66b6a305297 100644 --- a/driver-sync/src/test/functional/com/mongodb/internal/connection/OidcAuthenticationProseTests.java +++ b/driver-sync/src/test/functional/com/mongodb/internal/connection/OidcAuthenticationProseTests.java @@ -16,14 +16,13 @@ package com.mongodb.internal.connection; +import com.mongodb.ClusterFixture; import com.mongodb.ConnectionString; import com.mongodb.MongoClientSettings; import com.mongodb.MongoCommandException; import com.mongodb.MongoConfigurationException; import com.mongodb.MongoCredential; -import com.mongodb.MongoCredential.IdpResponse; -import com.mongodb.MongoCredential.OidcRefreshCallback; -import com.mongodb.MongoSecurityException; +import com.mongodb.MongoCredential.RequestCallbackResult; import com.mongodb.client.MongoClient; import com.mongodb.client.MongoClients; import com.mongodb.client.TestListener; @@ -39,48 +38,36 @@ import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.function.Executable; -import org.junit.jupiter.params.ParameterizedTest; -import org.junit.jupiter.params.provider.CsvSource; import org.opentest4j.AssertionFailedError; -import org.opentest4j.MultipleFailuresError; import java.io.IOException; +import java.lang.reflect.Field; import java.nio.charset.StandardCharsets; import java.nio.file.Files; -import java.nio.file.NoSuchFileException; import java.nio.file.Path; import java.nio.file.Paths; import java.time.Duration; +import java.util.ArrayList; import java.util.Arrays; import java.util.List; import java.util.Random; -import java.util.concurrent.ConcurrentLinkedQueue; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.ExecutionException; import java.util.concurrent.atomic.AtomicInteger; import java.util.function.Supplier; import java.util.stream.Collectors; -import java.util.stream.Stream; -import static com.mongodb.MongoCredential.ALLOWED_HOSTS_KEY; -import static com.mongodb.MongoCredential.IdpInfo; -import static com.mongodb.MongoCredential.OidcRefreshContext; import static com.mongodb.MongoCredential.OidcRequestCallback; import static com.mongodb.MongoCredential.OidcRequestContext; -import static com.mongodb.MongoCredential.PROVIDER_NAME_KEY; -import static com.mongodb.MongoCredential.REFRESH_TOKEN_CALLBACK_KEY; -import static com.mongodb.MongoCredential.REQUEST_TOKEN_CALLBACK_KEY; -import static com.mongodb.MongoCredential.createOidcCredential; -import static com.mongodb.client.TestHelper.setEnvironmentVariable; +import static com.mongodb.MongoCredential.OIDC_CALLBACK_KEY; import static java.lang.System.getenv; import static java.util.Arrays.asList; import static org.junit.jupiter.api.Assertions.assertEquals; -import static org.junit.jupiter.api.Assertions.assertFalse; import static org.junit.jupiter.api.Assertions.assertThrows; -import static org.junit.jupiter.api.Assertions.assertTrue; import static org.junit.jupiter.api.Assertions.fail; import static org.junit.jupiter.api.Assumptions.assumeTrue; import static util.ThreadTestHelpers.executeAll; - /** * See * Prose Tests. @@ -91,28 +78,25 @@ public static boolean oidcTestsEnabled() { return Boolean.parseBoolean(getenv().get("OIDC_TESTS_ENABLED")); } - private static final String AWS_WEB_IDENTITY_TOKEN_FILE = "AWS_WEB_IDENTITY_TOKEN_FILE"; + private String appName; - public static final String TOKEN_DIRECTORY = "/tmp/tokens/"; // TODO-OIDC + protected static String getOidcUri() { + ConnectionString cs = ClusterFixture.getConnectionString(); + // remove username and password + return "mongodb+srv://" + cs.getHosts().get(0) + "/?authMechanism=MONGODB-OIDC"; + } - protected static final String OIDC_URL = "mongodb://localhost/?authMechanism=MONGODB-OIDC"; - private static final String AWS_OIDC_URL = - "mongodb://localhost/?authMechanism=MONGODB-OIDC&authMechanismProperties=PROVIDER_NAME:aws"; - private String appName; + private static String getAwsOidcUri() { + return getOidcUri() + "&authMechanismProperties=PROVIDER_NAME:aws"; + } protected MongoClient createMongoClient(final MongoClientSettings settings) { return MongoClients.create(settings); } - protected void setOidcFile(final String file) { - setEnvironmentVariable(AWS_WEB_IDENTITY_TOKEN_FILE, TOKEN_DIRECTORY + file); - } - @BeforeEach public void beforeEach() { assumeTrue(oidcTestsEnabled()); - // In each test, clearing the cache is not required, since there is no global cache - setOidcFile("test_user1"); InternalStreamConnection.setRecordEverything(true); this.appName = this.getClass().getSimpleName() + "-" + new Random().nextInt(Integer.MAX_VALUE); } @@ -122,196 +106,77 @@ public void afterEach() { InternalStreamConnection.setRecordEverything(false); } - @ParameterizedTest - @CsvSource(delimiter = '#', value = { - // 1.1 to 1.5: - "test1p1 # test_user1 # " + OIDC_URL, - "test1p2 # test_user1 # mongodb://test_user1@localhost/?authMechanism=MONGODB-OIDC", - "test1p3 # test_user1 # mongodb://test_user1@localhost:27018/?authMechanism=MONGODB-OIDC&directConnection=true&readPreference=secondaryPreferred", - "test1p4 # test_user2 # mongodb://test_user2@localhost:27018/?authMechanism=MONGODB-OIDC&directConnection=true&readPreference=secondaryPreferred", - "test1p5 # invalid # mongodb://localhost:27018/?authMechanism=MONGODB-OIDC&directConnection=true&readPreference=secondaryPreferred", - }) - public void test1CallbackDrivenAuth(final String name, final String file, final String url) { - boolean shouldPass = !file.equals("invalid"); - setOidcFile(file); - // #. Create a request callback that returns a valid token. - OidcRequestCallback onRequest = createCallback(); - // #. Create a client with a URL of the form ... and the OIDC request callback. - MongoClientSettings clientSettings = createSettings(url, onRequest, null); - // #. Perform a find operation that succeeds / fails - if (shouldPass) { - performFind(clientSettings); - } else { - performFind( - clientSettings, - MongoCommandException.class, - "Command failed with error 18 (AuthenticationFailed)"); - } - } - - @ParameterizedTest - @CsvSource(delimiter = '#', value = { - // 1.6, both variants: - "'' # " + OIDC_URL, - "example.com # mongodb://localhost/?authMechanism=MONGODB-OIDC&ignored=example.com", - }) - public void test1p6CallbackDrivenAuthAllowedHostsBlocked(final String allowedHosts, final String url) { - // Create a client that uses the OIDC url and a request callback, and an ALLOWED_HOSTS that contains... - List allowedHostsList = asList(allowedHosts.split(",")); - MongoClientSettings settings = createSettings(url, createCallback(), null, allowedHostsList, null); - // #. Assert that a find operation fails with a client-side error. - performFind(settings, MongoSecurityException.class, ""); - } - @Test - public void test1p7LockAvoidsExtraCallbackCalls() { - proveThatConcurrentCallbacksThrow(); - // The test requires that two operations are attempted concurrently. - // The delay on the next find should cause the initial request to delay - // and the ensuing refresh to block, rather than entering onRefresh. - // After blocking, this ensuing refresh thread will enter onRefresh. - AtomicInteger concurrent = new AtomicInteger(); - TestCallback onRequest = createCallback().setExpired().setConcurrentTracker(concurrent); - TestCallback onRefresh = createCallback().setConcurrentTracker(concurrent); - MongoClientSettings clientSettings = createSettings(OIDC_URL, onRequest, onRefresh); - try (MongoClient mongoClient = createMongoClient(clientSettings)) { - delayNextFind(); // cause both callbacks to be called - executeAll(2, () -> performFind(mongoClient)); - assertEquals(1, onRequest.getInvocations()); - assertEquals(1, onRefresh.getInvocations()); - } - } - - public void proveThatConcurrentCallbacksThrow() { - // ensure that, via delay, test callbacks throw when invoked concurrently - AtomicInteger c = new AtomicInteger(); - TestCallback request = createCallback().setConcurrentTracker(c).setDelayMs(5); - TestCallback refresh = createCallback().setConcurrentTracker(c); - IdpInfo serverInfo = new OidcAuthenticator.IdpInfoImpl("issuer", "clientId", asList()); - executeAll(() -> { - sleep(2); - assertThrows(RuntimeException.class, () -> { - refresh.onRefresh(new OidcAuthenticator.OidcRefreshContextImpl(serverInfo, "refToken", Duration.ofSeconds(1234))); - }); - }, () -> { - request.onRequest(new OidcAuthenticator.OidcRequestContextImpl(serverInfo, Duration.ofSeconds(1234))); - }); - } - - private void sleep(final long ms) { - try { - Thread.sleep(ms); - } catch (InterruptedException e) { - throw new RuntimeException(e); - } - } - - @ParameterizedTest - @CsvSource(delimiter = '#', value = { - // 2.1 to 2.3: - "test2p1 # test_user1 # " + AWS_OIDC_URL, - "test2p2 # test_user1 # mongodb://localhost:27018/?authMechanism=MONGODB-OIDC&authMechanismProperties=PROVIDER_NAME:aws&directConnection=true&readPreference=secondaryPreferred", - "test2p3 # test_user2 # mongodb://localhost:27018/?authMechanism=MONGODB-OIDC&authMechanismProperties=PROVIDER_NAME:aws&directConnection=true&readPreference=secondaryPreferred", - }) - public void test2AwsAutomaticAuth(final String name, final String file, final String url) { - setOidcFile(file); - // #. Create a client with a url of the form ... - MongoCredential credential = createOidcCredential(null) - .withMechanismProperty(PROVIDER_NAME_KEY, "aws"); - MongoClientSettings clientSettings = MongoClientSettings.builder() - .applicationName(appName) - .credential(credential) - .applyConnectionString(new ConnectionString(url)) - .build(); - // #. Perform a find operation that succeeds. + public void test1p1CallbackIsCalledDuringAuth() { + // #. Create a ``MongoClient`` configured with an OIDC callback... + TestCallback onRequest = createCallback(); + MongoClientSettings clientSettings = createSettings(getOidcUri(), onRequest, null); + // #. Perform a find operation that succeeds performFind(clientSettings); + assertEquals(1, onRequest.invocations.get()); } @Test - public void test2p4AllowedHostsIgnored() { - MongoClientSettings settings = createSettings( - AWS_OIDC_URL, null, null, Arrays.asList(), null); - performFind(settings); + public void test1p2CallbackCalledOnceForMultipleConnections() { + TestCallback onRequest = createCallback(); + MongoClientSettings clientSettings = createSettings(getOidcUri(), onRequest, null); + try (MongoClient mongoClient = createMongoClient(clientSettings)) { + List threads = new ArrayList<>(); + for (int i = 0; i < 10; i++) { + Thread t = new Thread(() -> performFind(mongoClient)); + t.setDaemon(true); + t.start(); + threads.add(t); + } + for (Thread t : threads) { + try { + t.join(); + } catch (InterruptedException e) { + throw new RuntimeException(e); + } + } + } + assertEquals(1, onRequest.invocations.get()); } @Test - public void test3p1ValidCallbacks() { - String connectionString = "mongodb://test_user1@localhost/?authMechanism=MONGODB-OIDC"; - String expectedClientId = "0oadp0hpl7q3UIehP297"; - String expectedIssuer = "https://ebgxby0dw8.execute-api.us-west-1.amazonaws.com/default/mock-identity-config-oidc"; + public void test2p1ValidCallbackInputs() { + String connectionString = getOidcUri(); Duration expectedSeconds = Duration.ofMinutes(5); - TestCallback onRequest = createCallback().setExpired(); - TestCallback onRefresh = createCallback(); + TestCallback onRequest = createCallback(); // #. Verify that the request callback was called with the appropriate // inputs, including the timeout parameter if possible. - // #. Verify that the refresh callback was called with the appropriate - // inputs, including the timeout parameter if possible. OidcRequestCallback onRequest2 = (context) -> { - assertEquals(expectedClientId, context.getIdpInfo().getClientId()); - assertEquals(expectedIssuer, context.getIdpInfo().getIssuer()); - assertEquals(Arrays.asList(), context.getIdpInfo().getRequestScopes()); assertEquals(expectedSeconds, context.getTimeout()); return onRequest.onRequest(context); }; - OidcRefreshCallback onRefresh2 = (context) -> { - assertEquals(expectedClientId, context.getIdpInfo().getClientId()); - assertEquals(expectedIssuer, context.getIdpInfo().getIssuer()); - assertEquals(Arrays.asList(), context.getIdpInfo().getRequestScopes()); - assertEquals(expectedSeconds, context.getTimeout()); - assertEquals("refreshToken", context.getRefreshToken()); - return onRefresh.onRefresh(context); - }; - MongoClientSettings clientSettings = createSettings(connectionString, onRequest2, onRefresh2); + MongoClientSettings clientSettings = createSettings(connectionString, onRequest2); try (MongoClient mongoClient = createMongoClient(clientSettings)) { - delayNextFind(); // cause both callbacks to be called - executeAll(2, () -> performFind(mongoClient)); - // Ensure that both callbacks were called + performFind(mongoClient); + // callback was called assertEquals(1, onRequest.getInvocations()); - assertEquals(1, onRefresh.getInvocations()); } } @Test - public void test3p2RequestCallbackReturnsNull() { + public void test2p2RequestCallbackReturnsNull() { //noinspection ConstantConditions OidcRequestCallback onRequest = (context) -> null; - MongoClientSettings settings = this.createSettings(OIDC_URL, onRequest, null); + MongoClientSettings settings = this.createSettings(getOidcUri(), onRequest, null); performFind(settings, MongoConfigurationException.class, "Result of callback must not be null"); } @Test - public void test3p3RefreshCallbackReturnsNull() { - TestCallback onRequest = createCallback().setExpired().setDelayMs(100); - //noinspection ConstantConditions - OidcRefreshCallback onRefresh = (context) -> null; - MongoClientSettings clientSettings = createSettings(OIDC_URL, onRequest, onRefresh); - try (MongoClient mongoClient = createMongoClient(clientSettings)) { - delayNextFind(); // cause both callbacks to be called - try { - executeAll(2, () -> performFind(mongoClient)); - } catch (MultipleFailuresError actual) { - assertEquals(1, actual.getFailures().size()); - assertCause( - MongoConfigurationException.class, - "Result of callback must not be null", - actual.getFailures().get(0)); - } - assertEquals(1, onRequest.getInvocations()); - } - } - - @Test - public void test3p4RequestCallbackReturnsInvalidData() { + public void test2p3CallbackReturnsMissingData() { // #. Create a client with a request callback that returns data not // conforming to the OIDCRequestTokenResult with missing field(s). - // #. ... with extra field(s). - not possible OidcRequestCallback onRequest = (context) -> { //noinspection ConstantConditions - return new IdpResponse(null, null, null); + return new RequestCallbackResult(null); }; // we ensure that the error is propagated - MongoClientSettings clientSettings = createSettings(OIDC_URL, onRequest, null); + MongoClientSettings clientSettings = createSettings(getOidcUri(), onRequest, null); try (MongoClient mongoClient = createMongoClient(clientSettings)) { try { performFind(mongoClient); @@ -323,399 +188,100 @@ public void test3p4RequestCallbackReturnsInvalidData() { } @Test - public void test3p5RefreshCallbackReturnsInvalidData() { - TestCallback onRequest = createCallback().setExpired(); - OidcRefreshCallback onRefresh = (context) -> { - //noinspection ConstantConditions - return new IdpResponse(null, null, null); - }; - MongoClientSettings clientSettings = createSettings(OIDC_URL, onRequest, onRefresh); - try (MongoClient mongoClient = createMongoClient(clientSettings)) { - try { - executeAll(2, () -> performFind(mongoClient)); - } catch (MultipleFailuresError actual) { - assertEquals(1, actual.getFailures().size()); - assertCause( - IllegalArgumentException.class, - "accessToken can not be null", - actual.getFailures().get(0)); - } - assertEquals(1, onRequest.getInvocations()); - } - } - - // 3.6 Refresh Callback Returns Extra Data - not possible due to use of class - - @Test - public void test4p1CachedCredentialsCacheWithRefresh() { - // #. Create a new client with a request callback that gives credentials that expire in one minute. - TestCallback onRequest = createCallback().setExpired(); - TestCallback onRefresh = createCallback(); - MongoClientSettings clientSettings = createSettings(OIDC_URL, onRequest, onRefresh); - try (MongoClient mongoClient = createMongoClient(clientSettings)) { - // #. Create a new client with the same request callback and a refresh callback. - // Instead: - // 1. Delay the first find, causing the second find to authenticate a second connection - delayNextFind(); // cause both callbacks to be called - executeAll(2, () -> performFind(mongoClient)); - // #. Ensure that a find operation adds credentials to the cache. - // #. Ensure that a find operation results in a call to the refresh callback. - assertEquals(1, onRequest.getInvocations()); - assertEquals(1, onRefresh.getInvocations()); - // the refresh invocation will fail if the cached tokens are null - // so a success implies that credentials were present in the cache - } - } - - @Test - public void test4p2CachedCredentialsCacheWithNoRefresh() { - // #. Create a new client with a request callback that gives credentials that expire in one minute. - // #. Ensure that a find operation adds credentials to the cache. - // #. Create a new client with a request callback but no refresh callback. - // #. Ensure that a find operation results in a call to the request callback. - TestCallback onRequest = createCallback().setExpired(); - MongoClientSettings clientSettings = createSettings(OIDC_URL, onRequest, null); - try (MongoClient mongoClient = createMongoClient(clientSettings)) { - delayNextFind(); // cause both callbacks to be called - executeAll(2, () -> performFind(mongoClient)); - // test is the same as 4.1, but no onRefresh, and assert that the onRequest is called twice - assertEquals(2, onRequest.getInvocations()); - } - } - - // 4.3 Cache key includes callback - skipped: - // If the driver does not support using callback references or hashes as part of the cache key, skip this test. - - @Test - public void test4p4ErrorClearsCache() { - // #. Create a new client with a valid request callback that - // gives credentials that expire within 5 minutes and - // a refresh callback that gives invalid credentials. - - TestListener listener = new TestListener(); - ConcurrentLinkedQueue tokens = tokenQueue( - "test_user1", - "test_user1_expires", - "test_user1_expires", - "test_user1_1"); - TestCallback onRequest = createCallback() - .setExpired() - .setPathSupplier(() -> tokens.remove()) - .setEventListener(listener); - TestCallback onRefresh = createCallback() - .setPathSupplier(() -> tokens.remove()) - .setEventListener(listener); - - TestCommandListener commandListener = new TestCommandListener(listener); - - MongoClientSettings clientSettings = createSettings(OIDC_URL, onRequest, onRefresh, null, commandListener); - try (MongoClient mongoClient = createMongoClient(clientSettings)) { - // #. Ensure that a find operation adds a new entry to the cache. - performFind(mongoClient); - assertEquals(Arrays.asList( - "isMaster started", - "isMaster succeeded", - "onRequest invoked", - "read access token: test_user1", - "saslContinue started", - "saslContinue succeeded", - "find started", - "find succeeded" - ), listener.getEventStrings()); - listener.clear(); - - // #. Ensure that a subsequent find operation results in a 391 error. - failCommand(391, 1, "find"); - // ensure that the operation entirely fails, after attempting both potential fallback callbacks - assertThrows(MongoSecurityException.class, () -> performFind(mongoClient)); - assertEquals(Arrays.asList( - "find started", - "find failed", - "onRefresh invoked", - "read access token: test_user1_expires", - "saslStart started", - "saslStart failed", - // falling back to principal request, onRequest callback. - "saslStart started", - "saslStart succeeded", - "onRequest invoked", - "read access token: test_user1_expires", - "saslContinue started", - "saslContinue failed" - ), listener.getEventStrings()); - listener.clear(); - - // #. Ensure that the cache value cleared. - failCommand(391, 1, "find"); - performFind(mongoClient); - assertEquals(Arrays.asList( - "find started", - "find failed", - // falling back to principal request, onRequest callback. - // this implies that the cache has been cleared during the - // preceding find operation. - "saslStart started", - "saslStart succeeded", - "onRequest invoked", - "read access token: test_user1_1", - "saslContinue started", - "saslContinue succeeded", - // auth has finished - "find started", - "find succeeded" - ), listener.getEventStrings()); - listener.clear(); + public void test2p4InvalidClientConfigurationWithCallback() { + String awsOidcUri = getAwsOidcUri(); + MongoClientSettings settings = createSettings( + awsOidcUri, createCallback(), null); + try { + performFind(settings); + fail(); + } catch (Exception e) { + assertCause(IllegalArgumentException.class, + "OIDC_CALLBACK must not be specified when PROVIDER_NAME is specified", e); } } - // not a prose test. @Test - public void testEventListenerMustNotLogReauthentication() { - InternalStreamConnection.setRecordEverything(false); - - TestListener listener = new TestListener(); - ConcurrentLinkedQueue tokens = tokenQueue( - "test_user1", - "test_user1_expires", - "test_user1_expires", - "test_user1_1"); - TestCallback onRequest = createCallback() - .setExpired() - .setPathSupplier(() -> tokens.remove()) - .setEventListener(listener); - TestCallback onRefresh = createCallback() - .setPathSupplier(() -> tokens.remove()) - .setEventListener(listener); - - TestCommandListener commandListener = new TestCommandListener(listener); - - MongoClientSettings clientSettings = createSettings(OIDC_URL, onRequest, onRefresh, null, commandListener); - try (MongoClient mongoClient = createMongoClient(clientSettings)) { - performFind(mongoClient); - assertEquals(Arrays.asList( - "onRequest invoked", - "read access token: test_user1", - "find started", - "find succeeded" - ), listener.getEventStrings()); - listener.clear(); - - failCommand(391, 1, "find"); - assertThrows(MongoSecurityException.class, () -> performFind(mongoClient)); - assertEquals(Arrays.asList( - "find started", - "find failed", - "onRefresh invoked", - "read access token: test_user1_expires", - // falling back to principal request, onRequest callback - "onRequest invoked", - "read access token: test_user1_expires" - ), listener.getEventStrings()); - } - } + public void test3p1AuthFailsWithCachedToken() throws ExecutionException, InterruptedException, NoSuchFieldException, IllegalAccessException { + TestCallback onRequestWrapped = createCallback(); + CompletableFuture poisonToken = new CompletableFuture<>(); + OidcRequestCallback onRequest = (context) -> { + RequestCallbackResult result = onRequestWrapped.onRequest(context); + String accessToken = result.getAccessToken(); + if (!poisonToken.isDone()) { + poisonToken.complete(accessToken); + } + return result; + }; - @Test - public void test4p5AwsAutomaticWorkflowDoesNotUseCache() { - // #. Create a new client that uses the AWS automatic workflow. - // #. Ensure that a find operation does not add credentials to the cache. - setOidcFile("test_user1"); - MongoCredential credential = createOidcCredential(null) - .withMechanismProperty(PROVIDER_NAME_KEY, "aws"); - ConnectionString connectionString = new ConnectionString(AWS_OIDC_URL); - MongoClientSettings clientSettings = MongoClientSettings.builder() - .applicationName(appName) - .credential(credential) - .applyConnectionString(connectionString) - .build(); + MongoClientSettings clientSettings = createSettings(getOidcUri(), onRequest, null); try (MongoClient mongoClient = createMongoClient(clientSettings)) { + // populate cache performFind(mongoClient); - // This ensures that the next find failure results in a file (rather than cache) read - failCommand(391, 1, "find"); - setOidcFile("invalid_file"); - assertCause(NoSuchFileException.class, "invalid_file", () -> performFind(mongoClient)); - } - } - - @Test - public void test5SpeculativeAuthentication() { - // #. We can only test the successful case, by verifying that saslStart is not called. - // #. Create a client with a request callback that returns a valid token that will not expire soon. - TestListener listener = new TestListener(); - TestCallback onRequest = createCallback().setEventListener(listener); - TestCommandListener commandListener = new TestCommandListener(listener); - MongoClientSettings clientSettings = createSettings(OIDC_URL, onRequest, null, null, commandListener); - try (MongoClient mongoClient = createMongoClient(clientSettings)) { - // instead of setting failpoints for saslStart, we inspect events - delayNextFind(); + assertEquals(1, onRequestWrapped.invocations.get()); + // Poison the *Client Cache* with an invalid access token. + // uses reflection + String poisonString = poisonToken.get(); + Field f = String.class.getDeclaredField("value"); + f.setAccessible(true); + byte[] poisonChars = (byte[]) f.get(poisonString); + poisonChars[0] = '~'; + poisonChars[1] = '~'; + + assertEquals(1, onRequestWrapped.invocations.get()); + + // cause another connection to be opened + delayNextFind(); // cause both callbacks to be called executeAll(2, () -> performFind(mongoClient)); - - List events = listener.getEventStrings(); - assertFalse(events.stream().anyMatch(e -> e.contains("saslStart"))); - // onRequest is 2-step, so we expect 2 continues - assertEquals(2, events.stream().filter(e -> e.contains("saslContinue started")).count()); - // confirm all commands are enabled - assertTrue(events.stream().anyMatch(e -> e.contains("isMaster started"))); - } - } - - // Not a prose test - @Test - public void testAutomaticAuthUsesSpeculative() { - TestListener listener = new TestListener(); - TestCommandListener commandListener = new TestCommandListener(listener); - - MongoClientSettings settings = createSettings( - AWS_OIDC_URL, null, null, Arrays.asList(), commandListener); - try (MongoClient mongoClient = createMongoClient(settings)) { - // we use a listener instead of a failpoint - performFind(mongoClient); - assertEquals(Arrays.asList( - "isMaster started", - "isMaster succeeded", - "find started", - "find succeeded" - ), listener.getEventStrings()); } + assertEquals(2, onRequestWrapped.invocations.get()); } @Test - public void test6p1ReauthenticationSucceeds() { - // #. Create request and refresh callbacks that return valid credentials that will not expire soon. - TestListener listener = new TestListener(); - TestCallback onRequest = createCallback().setEventListener(listener); - TestCallback onRefresh = createCallback().setEventListener(listener); - - // #. Create a client with the callbacks and an event listener capable of listening for SASL commands. - TestCommandListener commandListener = new TestCommandListener(listener); - - MongoClientSettings clientSettings = createSettings(OIDC_URL, onRequest, onRefresh, null, commandListener); + public void test3p2AuthFailsWithoutCachedToken() { + MongoClientSettings clientSettings = createSettings(getOidcUri(), + (x) -> new RequestCallbackResult("invalid_token"), null); try (MongoClient mongoClient = createMongoClient(clientSettings)) { - - // #. Perform a find operation that succeeds. - performFind(mongoClient); - - // #. Assert that the refresh callback has not been called. - assertEquals(0, onRefresh.getInvocations()); - - assertEquals(Arrays.asList( - // speculative: - "isMaster started", - "isMaster succeeded", - // onRequest: - "onRequest invoked", - "read access token: test_user1", - // jwt from onRequest: - "saslContinue started", - "saslContinue succeeded", - // ensuing find: - "find started", - "find succeeded" - ), listener.getEventStrings()); - - // #. Clear the listener state if possible. - commandListener.reset(); - listener.clear(); - - // #. Force a reauthenication using a failCommand - failCommand(391, 1, "find"); - - // #. Perform another find operation that succeeds. - performFind(mongoClient); - - // #. Assert that the ordering of command started events is: find, find. - // #. Assert that the ordering of command succeeded events is: find. - // #. Assert that a find operation failed once during the command execution. - assertEquals(Arrays.asList( - "find started", - "find failed", - // find has triggered 391, and cleared the access token; fall back to refresh: - "onRefresh invoked", - "read access token: test_user1", - "saslStart started", - "saslStart succeeded", - // find retry succeeds: - "find started", - "find succeeded" - ), listener.getEventStrings()); - - // #. Assert that the refresh callback has been called once, if possible. - assertEquals(1, onRefresh.getInvocations()); + try { + performFind(mongoClient); + fail(); + } catch (Exception e) { + assertCause(MongoCommandException.class, + "Command failed with error 18 (AuthenticationFailed):", e); + } } } - @NotNull - private ConcurrentLinkedQueue tokenQueue(final String... queue) { - return Stream - .of(queue) - .map(v -> TOKEN_DIRECTORY + v) - .collect(Collectors.toCollection(ConcurrentLinkedQueue::new)); - } @Test - public void test6p2ReauthenticationRetriesAndSucceedsWithCache() { - // #. Create request and refresh callbacks that return valid credentials that will not expire soon. + public void test4p1Reauthentication() { TestCallback onRequest = createCallback(); - TestCallback onRefresh = createCallback(); - MongoClientSettings clientSettings = createSettings(OIDC_URL, onRequest, onRefresh); + MongoClientSettings clientSettings = createSettings(getOidcUri(), onRequest); try (MongoClient mongoClient = createMongoClient(clientSettings)) { - // #. Perform a find operation that succeeds. - performFind(mongoClient); - // #. Force a reauthenication using a failCommand failCommand(391, 1, "find"); // #. Perform a find operation that succeeds. performFind(mongoClient); } - } - - // 6.3 Retries and Fails with no Cache - // Appears to be untestable, since it requires 391 failure on jwt (may be fixed in future spec) - - @Test - public void test6p4SeparateConnectionsAvoidExtraCallbackCalls() { - ConcurrentLinkedQueue tokens = tokenQueue( - "test_user1", - "test_user1_1"); - TestCallback onRequest = createCallback().setPathSupplier(() -> tokens.remove()); - TestCallback onRefresh = createCallback().setPathSupplier(() -> tokens.remove()); - MongoClientSettings clientSettings = createSettings(OIDC_URL, onRequest, onRefresh); - try (MongoClient mongoClient = createMongoClient(clientSettings)) { - // #. Peform a find operation on each ... that succeeds. - delayNextFind(); - executeAll(2, () -> performFind(mongoClient)); - // #. Ensure that the request callback has been called once and the refresh callback has not been called. - assertEquals(1, onRequest.getInvocations()); - assertEquals(0, onRefresh.getInvocations()); - - failCommand(391, 2, "find"); - executeAll(2, () -> performFind(mongoClient)); - - // #. Ensure that the request callback has been called once and the refresh callback has been called once. - assertEquals(1, onRequest.getInvocations()); - assertEquals(1, onRefresh.getInvocations()); - } + assertEquals(2, onRequest.invocations.get()); } public MongoClientSettings createSettings( final String connectionString, - @Nullable final OidcRequestCallback onRequest, - @Nullable final OidcRefreshCallback onRefresh) { - return createSettings(connectionString, onRequest, onRefresh, null, null); + @Nullable final OidcRequestCallback onRequest) { + return createSettings(connectionString, onRequest, null); } private MongoClientSettings createSettings( final String connectionString, @Nullable final OidcRequestCallback onRequest, - @Nullable final OidcRefreshCallback onRefresh, - @Nullable final List allowedHosts, @Nullable final CommandListener commandListener) { ConnectionString cs = new ConnectionString(connectionString); MongoCredential credential = cs.getCredential() - .withMechanismProperty(REQUEST_TOKEN_CALLBACK_KEY, onRequest) - .withMechanismProperty(REFRESH_TOKEN_CALLBACK_KEY, onRefresh) - .withMechanismProperty(ALLOWED_HOSTS_KEY, allowedHosts); + .withMechanismProperty(OIDC_CALLBACK_KEY, onRequest); MongoClientSettings.Builder builder = MongoClientSettings.builder() .applicationName(appName) .applyConnectionString(cs) + .retryReads(false) .credential(credential); if (commandListener != null) { builder.addCommandListener(commandListener); @@ -767,7 +333,7 @@ private static void assertCause( } protected void delayNextFind() { - try (MongoClient client = createMongoClient(createSettings(AWS_OIDC_URL, null, null))) { + try (MongoClient client = createMongoClient(createSettings(getAwsOidcUri(), null, null))) { BsonDocument failPointDocument = new BsonDocument("configureFailPoint", new BsonString("failCommand")) .append("mode", new BsonDocument("times", new BsonInt32(1))) .append("data", new BsonDocument() @@ -781,7 +347,7 @@ protected void delayNextFind() { protected void failCommand(final int code, final int times, final String... commands) { try (MongoClient mongoClient = createMongoClient(createSettings( - AWS_OIDC_URL, null, null))) { + getAwsOidcUri(), null, null))) { List list = Arrays.stream(commands).map(c -> new BsonString(c)).collect(Collectors.toList()); BsonDocument failPointDocument = new BsonDocument("configureFailPoint", new BsonString("failCommand")) .append("mode", new BsonDocument("times", new BsonInt32(times))) @@ -793,11 +359,9 @@ protected void failCommand(final int code, final int times, final String... comm } } - public static class TestCallback implements OidcRequestCallback, OidcRefreshCallback { + public static class TestCallback implements OidcRequestCallback { private final AtomicInteger invocations = new AtomicInteger(); @Nullable - private final Integer expiresInSeconds; - @Nullable private final Integer delayInMilliseconds; @Nullable private final AtomicInteger concurrentTracker; @@ -807,16 +371,14 @@ public static class TestCallback implements OidcRequestCallback, OidcRefreshCall private final Supplier pathSupplier; public TestCallback() { - this(60 * 60, null, new AtomicInteger(), null, null); + this(null, new AtomicInteger(), null, null); } public TestCallback( - @Nullable final Integer expiresInSeconds, @Nullable final Integer delayInMilliseconds, @Nullable final AtomicInteger concurrentTracker, @Nullable final TestListener testListener, @Nullable final Supplier pathSupplier) { - this.expiresInSeconds = expiresInSeconds; this.delayInMilliseconds = delayInMilliseconds; this.concurrentTracker = concurrentTracker; this.testListener = testListener; @@ -828,26 +390,15 @@ public int getInvocations() { } @Override - public IdpResponse onRequest(final OidcRequestContext context) { + public RequestCallbackResult onRequest(final OidcRequestContext context) { if (testListener != null) { testListener.add("onRequest invoked"); } return callback(); } - @Override - public IdpResponse onRefresh(final OidcRefreshContext context) { - if (context.getRefreshToken() == null) { - throw new IllegalArgumentException("refreshToken was null"); - } - if (testListener != null) { - testListener.add("onRefresh invoked"); - } - return callback(); - } - @NotNull - private IdpResponse callback() { + private RequestCallbackResult callback() { if (concurrentTracker != null) { if (concurrentTracker.get() > 0) { throw new RuntimeException("Callbacks should not be invoked by multiple threads."); @@ -857,7 +408,7 @@ private IdpResponse callback() { try { invocations.incrementAndGet(); Path path = Paths.get(pathSupplier == null - ? getenv(AWS_WEB_IDENTITY_TOKEN_FILE) + ? getenv(OidcAuthenticator.AWS_WEB_IDENTITY_TOKEN_FILE) : pathSupplier.get()); String accessToken; try { @@ -866,14 +417,10 @@ private IdpResponse callback() { } catch (IOException | InterruptedException e) { throw new RuntimeException(e); } - String refreshToken = "refreshToken"; if (testListener != null) { testListener.add("read access token: " + path.getFileName()); } - return new IdpResponse( - accessToken, - expiresInSeconds, - refreshToken); + return new RequestCallbackResult(accessToken); } finally { if (concurrentTracker != null) { concurrentTracker.decrementAndGet(); @@ -887,18 +434,8 @@ private void simulateDelay() throws InterruptedException { } } - public TestCallback setExpiresInSeconds(final Integer expiresInSeconds) { - return new TestCallback( - expiresInSeconds, - this.delayInMilliseconds, - this.concurrentTracker, - this.testListener, - this.pathSupplier); - } - public TestCallback setDelayMs(final int milliseconds) { return new TestCallback( - this.expiresInSeconds, milliseconds, this.concurrentTracker, this.testListener, @@ -907,7 +444,6 @@ public TestCallback setDelayMs(final int milliseconds) { public TestCallback setConcurrentTracker(final AtomicInteger c) { return new TestCallback( - this.expiresInSeconds, this.delayInMilliseconds, c, this.testListener, @@ -916,7 +452,6 @@ public TestCallback setConcurrentTracker(final AtomicInteger c) { public TestCallback setEventListener(final TestListener testListener) { return new TestCallback( - this.expiresInSeconds, this.delayInMilliseconds, this.concurrentTracker, testListener, @@ -925,16 +460,11 @@ public TestCallback setEventListener(final TestListener testListener) { public TestCallback setPathSupplier(final Supplier pathSupplier) { return new TestCallback( - this.expiresInSeconds, this.delayInMilliseconds, this.concurrentTracker, this.testListener, pathSupplier); } - - public TestCallback setExpired() { - return this.setExpiresInSeconds(60); - } } public TestCallback createCallback() { From 8b80055a52598d0a75f6d102484bb8c5524849a6 Mon Sep 17 00:00:00 2001 From: Maxim Katcharov Date: Mon, 4 Mar 2024 08:27:37 -0700 Subject: [PATCH 4/6] Add Human OIDC Workflow (#1316) * Add human workflow * Apply suggestions from code review Co-authored-by: Valentin Kovalenko * Add expiresIn, address PR comments * PR fixes * Fix compilation --------- Co-authored-by: Valentin Kovalenko --- .../main/com/mongodb/ConnectionString.java | 11 + .../src/main/com/mongodb/MongoCredential.java | 135 +++++- .../connection/OidcAuthenticator.java | 312 ++++++++++-- .../com/mongodb/AuthConnectionStringTest.java | 4 +- .../com/mongodb/client/unified/Entities.java | 6 +- .../OidcAuthenticationProseTests.java | 455 +++++++++++++++++- 6 files changed, 845 insertions(+), 78 deletions(-) diff --git a/driver-core/src/main/com/mongodb/ConnectionString.java b/driver-core/src/main/com/mongodb/ConnectionString.java index c5197b8b7d4..8bb802e9e70 100644 --- a/driver-core/src/main/com/mongodb/ConnectionString.java +++ b/driver-core/src/main/com/mongodb/ConnectionString.java @@ -47,7 +47,10 @@ import java.util.Objects; import java.util.Set; import java.util.concurrent.TimeUnit; +import java.util.stream.Collectors; +import java.util.stream.Stream; +import static com.mongodb.MongoCredential.ALLOWED_HOSTS_KEY; import static com.mongodb.internal.connection.OidcAuthenticator.OidcValidator.validateCreateOidcCredential; import static java.lang.String.format; import static java.util.Arrays.asList; @@ -282,6 +285,9 @@ public class ConnectionString { private static final Set ALLOWED_OPTIONS_IN_TXT_RECORD = new HashSet<>(asList("authsource", "replicaset", "loadbalanced")); private static final Logger LOGGER = Loggers.getLogger("uri"); + private static final List MECHANISM_KEYS_DISALLOWED_IN_CONNECTION_STRING = Stream.of(ALLOWED_HOSTS_KEY) + .map(k -> k.toLowerCase()) + .collect(Collectors.toList()); private final MongoCredential credential; private final boolean isSrvProtocol; @@ -917,6 +923,11 @@ private MongoCredential createCredentials(final Map> option } String key = mechanismPropertyKeyValue[0].trim().toLowerCase(); String value = mechanismPropertyKeyValue[1].trim(); + if (MECHANISM_KEYS_DISALLOWED_IN_CONNECTION_STRING.contains(key)) { + throw new IllegalArgumentException(format("The connection string contains disallowed mechanism properties. " + + "'%s' must be set on the credential programmatically.", key)); + } + if (key.equals("canonicalize_host_name")) { credential = credential.withMechanismProperty(key, Boolean.valueOf(value)); } else { diff --git a/driver-core/src/main/com/mongodb/MongoCredential.java b/driver-core/src/main/com/mongodb/MongoCredential.java index 4c10e1f640c..295803e55a4 100644 --- a/driver-core/src/main/com/mongodb/MongoCredential.java +++ b/driver-core/src/main/com/mongodb/MongoCredential.java @@ -25,6 +25,7 @@ import java.util.Arrays; import java.util.Collections; import java.util.HashMap; +import java.util.List; import java.util.Map; import java.util.Objects; @@ -187,7 +188,8 @@ public final class MongoCredential { * The provider name. The value must be a string. *

* If this is provided, - * {@link MongoCredential#OIDC_CALLBACK_KEY} + * {@link MongoCredential#OIDC_CALLBACK_KEY} and + * {@link MongoCredential#OIDC_HUMAN_CALLBACK_KEY} * must not be provided. * * @see #createOidcCredential(String) @@ -197,10 +199,13 @@ public final class MongoCredential { /** * This callback is invoked when the OIDC-based authenticator requests - * tokens from the identity provider. The type of the value must be - * {@link OidcRequestCallback}. + * a token. The type of the value must be {@link OidcCallback}. + * {@link IdpInfo} will not be supplied to the callback, + * and a {@linkplain OidcCallbackResult#getRefreshToken() refresh token} + * must not be returned by the callback. *

* If this is provided, {@link MongoCredential#PROVIDER_NAME_KEY} + * and {@link MongoCredential#OIDC_HUMAN_CALLBACK_KEY} * must not be provided. * * @see #createOidcCredential(String) @@ -208,6 +213,46 @@ public final class MongoCredential { */ public static final String OIDC_CALLBACK_KEY = "OIDC_CALLBACK"; + /** + * This callback is invoked when the OIDC-based authenticator requests + * a token from the identity provider (IDP) using the IDP information + * from the MongoDB server. The type of the value must be + * {@link OidcCallback}. + *

+ * If this is provided, {@link MongoCredential#PROVIDER_NAME_KEY} + * and {@link MongoCredential#OIDC_CALLBACK_KEY} + * must not be provided. + * + * @see #createOidcCredential(String) + * @since 4.10 + */ + public static final String OIDC_HUMAN_CALLBACK_KEY = "OIDC_HUMAN_CALLBACK"; + + + /** + * Mechanism key for a list of allowed hostnames or ip-addresses for MongoDB connections. Ports must be excluded. + * The hostnames may include a leading "*." wildcard, which allows for matching (potentially nested) subdomains. + * When MONGODB-OIDC authentication is attempted against a hostname that does not match any of list of allowed hosts + * the driver will raise an error. The type of the value must be {@code List}. + * + * @see MongoCredential#DEFAULT_ALLOWED_HOSTS + * @see #createOidcCredential(String) + * @since 4.10 + */ + public static final String ALLOWED_HOSTS_KEY = "ALLOWED_HOSTS"; + + /** + * The list of allowed hosts that will be used if no + * {@link MongoCredential#ALLOWED_HOSTS_KEY} value is supplied. + * The default allowed hosts are: + * {@code "*.mongodb.net", "*.mongodb-qa.net", "*.mongodb-dev.net", "*.mongodbgov.net", "localhost", "127.0.0.1", "::1"} + * + * @see #createOidcCredential(String) + * @since 4.10 + */ + public static final List DEFAULT_ALLOWED_HOSTS = Collections.unmodifiableList(Arrays.asList( + "*.mongodb.net", "*.mongodb-qa.net", "*.mongodb-dev.net", "*.mongodbgov.net", "localhost", "127.0.0.1", "::1")); + /** * Creates a MongoCredential instance with an unspecified mechanism. The client will negotiate the best mechanism based on the * version of the server that the client is authenticating to. @@ -365,6 +410,8 @@ public static MongoCredential createAwsCredential(@Nullable final String userNam * @see #withMechanismProperty(String, Object) * @see #PROVIDER_NAME_KEY * @see #OIDC_CALLBACK_KEY + * @see #OIDC_HUMAN_CALLBACK_KEY + * @see #ALLOWED_HOSTS_KEY * @mongodb.server.release 7.0 */ public static MongoCredential createOidcCredential(@Nullable final String userName) { @@ -593,10 +640,15 @@ public String toString() { } /** - * The context for the {@link OidcRequestCallback#onRequest(OidcRequestContext) OIDC request callback}. + * The context for the {@link OidcCallback#onRequest(OidcCallbackContext) OIDC request callback}. */ @Evolving - public interface OidcRequestContext { + public interface OidcCallbackContext { + /** + * @return The OIDC Identity Provider's configuration that can be used to acquire an Access Token. + */ + @Nullable + IdpInfo getIdpInfo(); /** * @return The timeout that this callback must complete within. @@ -607,6 +659,12 @@ public interface OidcRequestContext { * @return The OIDC callback API version. Currently, version 1. */ int getVersion(); + + /** + * @return The OIDC Refresh token supplied by a prior callback invocation. + */ + @Nullable + String getRefreshToken(); } /** @@ -616,27 +674,76 @@ public interface OidcRequestContext { * It does not have to be thread-safe, unless it is provided to multiple * MongoClients. */ - public interface OidcRequestCallback { + public interface OidcCallback { /** * @param context The context. * @return The response produced by an OIDC Identity Provider */ - RequestCallbackResult onRequest(OidcRequestContext context); + OidcCallbackResult onRequest(OidcCallbackContext context); + } + + /** + * The OIDC Identity Provider's configuration that can be used to acquire an Access Token. + */ + @Evolving + public interface IdpInfo { + /** + * @return URL which describes the Authorization Server. This identifier is the + * iss of provided access tokens, and is viable for RFC8414 metadata + * discovery and RFC9207 identification. + */ + String getIssuer(); + + /** + * @return Unique client ID for this OIDC client. + */ + String getClientId(); + + /** + * @return Additional scopes to request from Identity Provider. Immutable. + */ + List getRequestScopes(); } /** * The response produced by an OIDC Identity Provider. */ - public static final class RequestCallbackResult { + public static final class OidcCallbackResult { private final String accessToken; + private final Duration expiresIn; + + @Nullable + private final String refreshToken; + + /** + * @param accessToken The OIDC access token. + * @param expiresIn Time until the access token expires. + * A {@linkplain Duration#isZero() zero-length} duration + * means that the access token does not expire. + */ + public OidcCallbackResult(final String accessToken, final Duration expiresIn) { + this(accessToken, expiresIn, null); + } + /** - * @param accessToken The OIDC access token + * @param accessToken The OIDC access token. + * @param expiresIn Time until the access token expires. + * A {@linkplain Duration#isZero() zero-length} duration + * means that the access token does not expire. + * @param refreshToken The refresh token. If null, refresh will not be attempted. */ - public RequestCallbackResult(final String accessToken) { + public OidcCallbackResult(final String accessToken, final Duration expiresIn, + @Nullable final String refreshToken) { notNull("accessToken", accessToken); + notNull("expiresIn", expiresIn); + if (expiresIn.isNegative()) { + throw new IllegalArgumentException("expiresIn must not be a negative value"); + } this.accessToken = accessToken; + this.expiresIn = expiresIn; + this.refreshToken = refreshToken; } /** @@ -645,5 +752,13 @@ public RequestCallbackResult(final String accessToken) { public String getAccessToken() { return accessToken; } + + /** + * @return The OIDC refresh token. If null, refresh will not be attempted. + */ + @Nullable + public String getRefreshToken() { + return refreshToken; + } } } diff --git a/driver-core/src/main/com/mongodb/internal/connection/OidcAuthenticator.java b/driver-core/src/main/com/mongodb/internal/connection/OidcAuthenticator.java index 70f9682476c..6b2362cbc1f 100644 --- a/driver-core/src/main/com/mongodb/internal/connection/OidcAuthenticator.java +++ b/driver-core/src/main/com/mongodb/internal/connection/OidcAuthenticator.java @@ -21,7 +21,7 @@ import com.mongodb.MongoCommandException; import com.mongodb.MongoConfigurationException; import com.mongodb.MongoCredential; -import com.mongodb.MongoCredential.RequestCallbackResult; +import com.mongodb.MongoCredential.OidcCallbackResult; import com.mongodb.MongoException; import com.mongodb.MongoSecurityException; import com.mongodb.ServerAddress; @@ -34,6 +34,7 @@ import com.mongodb.lang.Nullable; import org.bson.BsonDocument; import org.bson.BsonString; +import org.bson.RawBsonDocument; import javax.security.sasl.SaslClient; import java.io.IOException; @@ -42,12 +43,18 @@ import java.nio.file.Paths; import java.time.Duration; import java.util.Arrays; +import java.util.Collections; import java.util.List; import java.util.Map; +import java.util.stream.Collectors; import static com.mongodb.AuthenticationMechanism.MONGODB_OIDC; -import static com.mongodb.MongoCredential.OidcRequestCallback; -import static com.mongodb.MongoCredential.OidcRequestContext; +import static com.mongodb.MongoCredential.ALLOWED_HOSTS_KEY; +import static com.mongodb.MongoCredential.DEFAULT_ALLOWED_HOSTS; +import static com.mongodb.MongoCredential.IdpInfo; +import static com.mongodb.MongoCredential.OIDC_HUMAN_CALLBACK_KEY; +import static com.mongodb.MongoCredential.OidcCallback; +import static com.mongodb.MongoCredential.OidcCallbackContext; import static com.mongodb.MongoCredential.PROVIDER_NAME_KEY; import static com.mongodb.MongoCredential.OIDC_CALLBACK_KEY; import static com.mongodb.assertions.Assertions.assertFalse; @@ -69,6 +76,9 @@ public final class OidcAuthenticator extends SaslAuthenticator { public static final String AWS_WEB_IDENTITY_TOKEN_FILE = "AWS_WEB_IDENTITY_TOKEN_FILE"; private static final int CALLBACK_API_VERSION_NUMBER = 1; + @Nullable + private ServerAddress serverAddress; + @Nullable private String connectionLastAccessToken; @@ -94,6 +104,7 @@ public String getMechanismName() { @Override protected SaslClient createSaslClient(final ServerAddress serverAddress) { + this.serverAddress = assertNotNull(serverAddress); MongoCredentialWithCache mongoCredentialWithCache = getMongoCredentialWithCache(); return new OidcSaslClient(mongoCredentialWithCache); } @@ -141,11 +152,25 @@ public void setSpeculativeAuthenticateResponse(@Nullable final BsonDocument resp speculativeAuthenticateResponse = response; } + private boolean isAutomaticAuthentication() { + return getOidcCallbackMechanismProperty(PROVIDER_NAME_KEY) == null; + } + + private boolean isHumanCallback() { + return getOidcCallbackMechanismProperty(OIDC_HUMAN_CALLBACK_KEY) != null; + } + @Nullable - private OidcRequestCallback getRequestCallback() { + private OidcCallback getOidcCallbackMechanismProperty(final String key) { return getMongoCredentialWithCache() .getCredential() - .getMechanismProperty(OIDC_CALLBACK_KEY, null); + .getMechanismProperty(key, null); + } + + @Nullable + private OidcCallback getRequestCallback() { + OidcCallback machine = getOidcCallbackMechanismProperty(OIDC_CALLBACK_KEY); + return machine != null ? machine : getOidcCallbackMechanismProperty(OIDC_HUMAN_CALLBACK_KEY); } @Override @@ -195,7 +220,7 @@ private void authenticationLoop(final InternalConnection connection, final Conne try { super.authenticate(connection, description); break; - } catch (MongoSecurityException e) { + } catch (Exception e) { if (triggersRetry(e) && shouldRetryHandler()) { continue; } @@ -219,17 +244,66 @@ private byte[] evaluate(final byte[] challenge) { } byte[][] jwt = new byte[1][]; Locks.withLock(getMongoCredentialWithCache().getOidcLock(), () -> { + OidcCacheEntry oidcCacheEntry = getMongoCredentialWithCache().getOidcCacheEntry(); + String cachedRefreshToken = oidcCacheEntry.getRefreshToken(); + IdpInfo cachedIdpInfo = oidcCacheEntry.getIdpInfo(); String cachedAccessToken = validatedCachedAccessToken(); + OidcCallback requestCallback = assertNotNull(getRequestCallback()); + boolean isHuman = isHumanCallback(); if (cachedAccessToken != null) { - jwt[0] = prepareTokenAsJwt(cachedAccessToken); fallbackState = FallbackState.PHASE_1_CACHED_TOKEN; + jwt[0] = prepareTokenAsJwt(cachedAccessToken); + } else if (cachedRefreshToken != null) { + // cached refresh token is only set when isHuman + // original IDP info will be present, if refresh token present + assertNotNull(cachedIdpInfo); + // Invoke Callback using cached Refresh Token + fallbackState = FallbackState.PHASE_2_REFRESH_CALLBACK_TOKEN; + OidcCallbackResult result = requestCallback.onRequest(new OidcCallbackContextImpl( + CALLBACK_TIMEOUT, cachedIdpInfo, cachedRefreshToken)); + jwt[0] = populateCacheWithCallbackResultAndPrepareJwt(cachedIdpInfo, result); } else { // cache is empty - OidcRequestCallback requestCallback = assertNotNull(getRequestCallback()); - RequestCallbackResult result = requestCallback.onRequest(new OidcRequestContextImpl(CALLBACK_TIMEOUT)); - jwt[0] = populateCacheWithCallbackResultAndPrepareJwt(result); - fallbackState = FallbackState.PHASE_2_CALLBACK_TOKEN; + + if (!isHuman) { + // no principal request + fallbackState = FallbackState.PHASE_3B_CALLBACK_TOKEN; + OidcCallbackResult result = requestCallback.onRequest(new OidcCallbackContextImpl( + CALLBACK_TIMEOUT)); + jwt[0] = populateCacheWithCallbackResultAndPrepareJwt(null, result); + if (result.getRefreshToken() != null) { + throw new MongoConfigurationException( + "Refresh token must only be provided in human workflow"); + } + } else { + /* + A check for present idp info short-circuits phase-3a. + If a challenge is present, it can only be a response to a + "principal-request", so the challenge must be the resulting + idp info. Such a request is made during speculative auth, + though the source is unimportant, as long as we detect and + use it here. + */ + boolean idpInfoNotPresent = challenge.length == 0; + /* + Checking that the fallback state is not phase-3a ensures that + this does not loop infinitely in the case of a bug. + */ + boolean alreadyTriedPrincipal = fallbackState == FallbackState.PHASE_3A_PRINCIPAL; + if (!alreadyTriedPrincipal && idpInfoNotPresent) { + // request for idp info, only in the human workflow + fallbackState = FallbackState.PHASE_3A_PRINCIPAL; + jwt[0] = prepareUsername(getMongoCredentialWithCache().getCredential().getUserName()); + } else { + IdpInfo idpInfo = toIdpInfo(challenge); + // there is no cached refresh token + fallbackState = FallbackState.PHASE_3B_CALLBACK_TOKEN; + OidcCallbackResult result = requestCallback.onRequest(new OidcCallbackContextImpl( + CALLBACK_TIMEOUT, idpInfo, null)); + jwt[0] = populateCacheWithCallbackResultAndPrepareJwt(idpInfo, result); + } + } } }); return jwt[0]; @@ -255,19 +329,35 @@ private String validatedCachedAccessToken() { return cachedAccessToken; } - private boolean isAutomaticAuthentication() { - return getRequestCallback() == null; - } - private boolean clientIsComplete() { - return true; // all possibilities are 1-step + return fallbackState != FallbackState.PHASE_3A_PRINCIPAL; } private boolean shouldRetryHandler() { + boolean[] result = new boolean[1]; Locks.withLock(getMongoCredentialWithCache().getOidcLock(), () -> { - validatedCachedAccessToken(); + MongoCredentialWithCache mongoCredentialWithCache = getMongoCredentialWithCache(); + OidcCacheEntry cacheEntry = mongoCredentialWithCache.getOidcCacheEntry(); + if (fallbackState == FallbackState.PHASE_1_CACHED_TOKEN) { + // a cached access token failed + mongoCredentialWithCache.setOidcCacheEntry(cacheEntry + .clearAccessToken()); + result[0] = true; + } else if (fallbackState == FallbackState.PHASE_2_REFRESH_CALLBACK_TOKEN) { + // a refresh token failed + mongoCredentialWithCache.setOidcCacheEntry(cacheEntry + .clearAccessToken() + .clearRefreshToken()); + result[0] = true; + } else { + // a clean-restart failed + mongoCredentialWithCache.setOidcCacheEntry(cacheEntry + .clearAccessToken() + .clearRefreshToken()); + result[0] = false; + } }); - return fallbackState == FallbackState.PHASE_1_CACHED_TOKEN; + return result[0]; } @Nullable @@ -280,24 +370,29 @@ private String getCachedAccessToken() { static final class OidcCacheEntry { @Nullable private final String accessToken; + @Nullable + private final String refreshToken; + @Nullable + private final IdpInfo idpInfo; @Override public String toString() { return "OidcCacheEntry{" + "\n accessToken=[omitted]" + + ",\n refreshToken=[omitted]" + + ",\n idpInfo=" + idpInfo + '}'; } - OidcCacheEntry(final RequestCallbackResult requestCallbackResult) { - this.accessToken = requestCallbackResult.getAccessToken(); - } - OidcCacheEntry() { - this((String) null); + this(null, null, null); } - private OidcCacheEntry(@Nullable final String accessToken) { + private OidcCacheEntry(@Nullable final String accessToken, + @Nullable final String refreshToken, @Nullable final IdpInfo idpInfo) { this.accessToken = accessToken; + this.refreshToken = refreshToken; + this.idpInfo = idpInfo; } @Nullable @@ -305,8 +400,28 @@ String getCachedAccessToken() { return accessToken; } + @Nullable + String getRefreshToken() { + return refreshToken; + } + + @Nullable + IdpInfo getIdpInfo() { + return idpInfo; + } + OidcCacheEntry clearAccessToken() { - return new OidcCacheEntry((String) null); + return new OidcCacheEntry( + null, + this.refreshToken, + this.idpInfo); + } + + OidcCacheEntry clearRefreshToken() { + return new OidcCacheEntry( + this.accessToken, + null, + null); } } @@ -343,13 +458,70 @@ private static String readAwsTokenFromFile() { } } - private byte[] populateCacheWithCallbackResultAndPrepareJwt(@Nullable final RequestCallbackResult requestCallbackResult) { - if (requestCallbackResult == null) { + private byte[] populateCacheWithCallbackResultAndPrepareJwt( + @Nullable final IdpInfo serverInfo, + @Nullable final OidcCallbackResult oidcCallbackResult) { + if (oidcCallbackResult == null) { throw new MongoConfigurationException("Result of callback must not be null"); } - OidcCacheEntry newEntry = new OidcCacheEntry(requestCallbackResult); + OidcCacheEntry newEntry = new OidcCacheEntry(oidcCallbackResult.getAccessToken(), + oidcCallbackResult.getRefreshToken(), serverInfo); getMongoCredentialWithCache().setOidcCacheEntry(newEntry); - return prepareTokenAsJwt(requestCallbackResult.getAccessToken()); + return prepareTokenAsJwt(oidcCallbackResult.getAccessToken()); + } + + private static byte[] prepareUsername(@Nullable final String username) { + BsonDocument document = new BsonDocument(); + if (username != null) { + document = document.append("n", new BsonString(username)); + } + return toBson(document); + } + + private IdpInfo toIdpInfo(final byte[] challenge) { + // validate here to prevent creating IdpInfo for unauthorized hosts + validateAllowedHosts(getMongoCredential()); + BsonDocument c = new RawBsonDocument(challenge); + String issuer = c.getString("issuer").getValue(); + String clientId = c.getString("clientId").getValue(); + return new IdpInfoImpl( + issuer, + clientId, + getStringArray(c, "requestScopes")); + } + + + @Nullable + private static List getStringArray(final BsonDocument document, final String key) { + if (!document.isArray(key)) { + return null; + } + return document.getArray(key).stream() + // ignore non-string values from server, rather than error + .filter(v -> v.isString()) + .map(v -> v.asString().getValue()) + .collect(Collectors.toList()); + } + + private void validateAllowedHosts(final MongoCredential credential) { + List allowedHosts = assertNotNull(credential.getMechanismProperty(ALLOWED_HOSTS_KEY, DEFAULT_ALLOWED_HOSTS)); + String host = assertNotNull(serverAddress).getHost(); + boolean permitted = allowedHosts.stream().anyMatch(allowedHost -> { + if (allowedHost.startsWith("*.")) { + String ending = allowedHost.substring(1); + return host.endsWith(ending); + } else if (allowedHost.contains("*")) { + throw new IllegalArgumentException( + "Allowed host " + allowedHost + " contains invalid wildcard"); + } else { + return host.equals(allowedHost); + } + }); + if (!permitted) { + throw new MongoSecurityException( + credential, "Host " + host + " not permitted by " + ALLOWED_HOSTS_KEY + + ", values: " + allowedHosts); + } } private byte[] prepareTokenAsJwt(final String accessToken) { @@ -400,32 +572,59 @@ public static void validateCreateOidcCredential(@Nullable final char[] password) public static void validateBeforeUse(final MongoCredential credential) { String userName = credential.getUserName(); Object providerName = credential.getMechanismProperty(PROVIDER_NAME_KEY, null); - Object requestCallback = credential.getMechanismProperty(OIDC_CALLBACK_KEY, null); + Object machineCallback = credential.getMechanismProperty(OIDC_CALLBACK_KEY, null); + Object humanCallback = credential.getMechanismProperty(OIDC_HUMAN_CALLBACK_KEY, null); if (providerName == null) { // callback - if (requestCallback == null) { - throw new IllegalArgumentException("Either " + PROVIDER_NAME_KEY + " or " - + OIDC_CALLBACK_KEY + " must be specified"); + if (machineCallback == null && humanCallback == null) { + throw new IllegalArgumentException("Either " + PROVIDER_NAME_KEY + + " or " + OIDC_CALLBACK_KEY + + " or " + OIDC_HUMAN_CALLBACK_KEY + + " must be specified"); + } + if (machineCallback != null && humanCallback != null) { + throw new IllegalArgumentException("Both " + OIDC_CALLBACK_KEY + + " and " + OIDC_HUMAN_CALLBACK_KEY + + " must not be specified"); } } else { - // automatic if (userName != null) { throw new IllegalArgumentException("user name must not be specified when " + PROVIDER_NAME_KEY + " is specified"); } - if (requestCallback != null) { + if (machineCallback != null) { throw new IllegalArgumentException(OIDC_CALLBACK_KEY + " must not be specified when " + PROVIDER_NAME_KEY + " is specified"); } + if (humanCallback != null) { + throw new IllegalArgumentException(OIDC_HUMAN_CALLBACK_KEY + " must not be specified when " + PROVIDER_NAME_KEY + " is specified"); + } } } } @VisibleForTesting(otherwise = VisibleForTesting.AccessModifier.PRIVATE) - static class OidcRequestContextImpl implements OidcRequestContext { + static class OidcCallbackContextImpl implements OidcCallbackContext { private final Duration timeout; + @Nullable + private final IdpInfo idpInfo; + @Nullable + private final String refreshToken; - OidcRequestContextImpl(final Duration timeout) { + OidcCallbackContextImpl(final Duration timeout) { this.timeout = assertNotNull(timeout); + this.idpInfo = null; + this.refreshToken = null; + } + + OidcCallbackContextImpl(final Duration timeout, final IdpInfo idpInfo, @Nullable final String refreshToken) { + this.timeout = assertNotNull(timeout); + this.idpInfo = assertNotNull(idpInfo); + this.refreshToken = refreshToken; + } + + @Override + public IdpInfo getIdpInfo() { + return idpInfo; } @Override @@ -437,6 +636,41 @@ public Duration getTimeout() { public int getVersion() { return CALLBACK_API_VERSION_NUMBER; } + + @Override + public String getRefreshToken() { + return refreshToken; + } + } + + @VisibleForTesting(otherwise = VisibleForTesting.AccessModifier.PRIVATE) + static final class IdpInfoImpl implements IdpInfo { + private final String issuer; + private final String clientId; + private final List requestScopes; + + IdpInfoImpl(final String issuer, final String clientId, @Nullable final List requestScopes) { + this.issuer = assertNotNull(issuer); + this.clientId = assertNotNull(clientId); + this.requestScopes = requestScopes == null + ? Collections.emptyList() + : Collections.unmodifiableList(requestScopes); + } + + @Override + public String getIssuer() { + return issuer; + } + + @Override + public String getClientId() { + return clientId; + } + + @Override + public List getRequestScopes() { + return requestScopes; + } } /** @@ -445,6 +679,8 @@ public int getVersion() { private enum FallbackState { INITIAL, PHASE_1_CACHED_TOKEN, - PHASE_2_CALLBACK_TOKEN + PHASE_2_REFRESH_CALLBACK_TOKEN, + PHASE_3A_PRINCIPAL, + PHASE_3B_CALLBACK_TOKEN } } diff --git a/driver-core/src/test/unit/com/mongodb/AuthConnectionStringTest.java b/driver-core/src/test/unit/com/mongodb/AuthConnectionStringTest.java index 4da83dc7d4f..cab5b0e0365 100644 --- a/driver-core/src/test/unit/com/mongodb/AuthConnectionStringTest.java +++ b/driver-core/src/test/unit/com/mongodb/AuthConnectionStringTest.java @@ -119,7 +119,7 @@ private MongoCredential getMongoCredential() { if ("oidcRequest".equals(string)) { credential = credential.withMechanismProperty( OIDC_CALLBACK_KEY, - (MongoCredential.OidcRequestCallback) (context) -> null); + (MongoCredential.OidcCallback) (context) -> null); } else { fail("Unsupported callback: " + string); } @@ -176,7 +176,7 @@ private void assertMechanismProperties(final MongoCredential credential) { } else if ((document.get(key).isBoolean())) { boolean expectedValue = document.getBoolean(key).getValue(); if (OIDC_CALLBACK_KEY.equals(key)) { - assertTrue(actualMechanismProperty instanceof MongoCredential.OidcRequestCallback); + assertTrue(actualMechanismProperty instanceof MongoCredential.OidcCallback); return; } assertNotNull(actualMechanismProperty); diff --git a/driver-sync/src/test/functional/com/mongodb/client/unified/Entities.java b/driver-sync/src/test/functional/com/mongodb/client/unified/Entities.java index 5f42066aada..7e83f802279 100644 --- a/driver-sync/src/test/functional/com/mongodb/client/unified/Entities.java +++ b/driver-sync/src/test/functional/com/mongodb/client/unified/Entities.java @@ -73,7 +73,6 @@ import org.bson.BsonDouble; import org.bson.BsonInt32; import org.bson.BsonInt64; -import org.bson.BsonNumber; import org.bson.BsonString; import org.bson.BsonValue; @@ -82,6 +81,7 @@ import java.nio.file.Files; import java.nio.file.Path; import java.nio.file.Paths; +import java.time.Duration; import java.util.ArrayList; import java.util.HashMap; import java.util.HashSet; @@ -542,7 +542,7 @@ private void initClient(final BsonDocument entity, final String id, if (isOidc && hasPlaceholder) { clientSettingsBuilder.credential(credential.withMechanismProperty( MongoCredential.OIDC_CALLBACK_KEY, - (MongoCredential.OidcRequestCallback) context -> { + (MongoCredential.OidcCallback) context -> { Path path = Paths.get(getenv(OidcAuthenticator.AWS_WEB_IDENTITY_TOKEN_FILE)); String accessToken; try { @@ -550,7 +550,7 @@ private void initClient(final BsonDocument entity, final String id, } catch (IOException e) { throw new RuntimeException(e); } - return new MongoCredential.RequestCallbackResult(accessToken); + return new MongoCredential.OidcCallbackResult(accessToken, Duration.ZERO); })); break; } diff --git a/driver-sync/src/test/functional/com/mongodb/internal/connection/OidcAuthenticationProseTests.java b/driver-sync/src/test/functional/com/mongodb/internal/connection/OidcAuthenticationProseTests.java index 66b6a305297..b5a87a51cef 100644 --- a/driver-sync/src/test/functional/com/mongodb/internal/connection/OidcAuthenticationProseTests.java +++ b/driver-sync/src/test/functional/com/mongodb/internal/connection/OidcAuthenticationProseTests.java @@ -16,13 +16,13 @@ package com.mongodb.internal.connection; -import com.mongodb.ClusterFixture; import com.mongodb.ConnectionString; import com.mongodb.MongoClientSettings; import com.mongodb.MongoCommandException; import com.mongodb.MongoConfigurationException; import com.mongodb.MongoCredential; -import com.mongodb.MongoCredential.RequestCallbackResult; +import com.mongodb.MongoSecurityException; +import com.mongodb.MongoSocketException; import com.mongodb.client.MongoClient; import com.mongodb.client.MongoClients; import com.mongodb.client.TestListener; @@ -49,21 +49,27 @@ import java.time.Duration; import java.util.ArrayList; import java.util.Arrays; +import java.util.Collections; import java.util.List; import java.util.Random; import java.util.concurrent.CompletableFuture; +import java.util.concurrent.ConcurrentLinkedQueue; import java.util.concurrent.ExecutionException; import java.util.concurrent.atomic.AtomicInteger; import java.util.function.Supplier; import java.util.stream.Collectors; -import static com.mongodb.MongoCredential.OidcRequestCallback; -import static com.mongodb.MongoCredential.OidcRequestContext; +import static com.mongodb.MongoCredential.ALLOWED_HOSTS_KEY; +import static com.mongodb.MongoCredential.OIDC_HUMAN_CALLBACK_KEY; +import static com.mongodb.MongoCredential.OidcCallbackResult; +import static com.mongodb.MongoCredential.OidcCallback; +import static com.mongodb.MongoCredential.OidcCallbackContext; import static com.mongodb.MongoCredential.OIDC_CALLBACK_KEY; import static java.lang.System.getenv; import static java.util.Arrays.asList; import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; import static org.junit.jupiter.api.Assertions.fail; import static org.junit.jupiter.api.Assumptions.assumeTrue; import static util.ThreadTestHelpers.executeAll; @@ -81,15 +87,37 @@ public static boolean oidcTestsEnabled() { private String appName; protected static String getOidcUri() { - ConnectionString cs = ClusterFixture.getConnectionString(); - // remove username and password + ConnectionString cs = new ConnectionString(getenv("OIDC_ATLAS_URI_SINGLE")); + // remove any username and password return "mongodb+srv://" + cs.getHosts().get(0) + "/?authMechanism=MONGODB-OIDC"; } + protected static String getOidcUri(final String username) { + ConnectionString cs = new ConnectionString(getenv("OIDC_ATLAS_URI_SINGLE")); + // set username + return "mongodb+srv://" + username + "@" + cs.getHosts().get(0) + "/?authMechanism=MONGODB-OIDC"; + } + + protected static String getOidcUriMulti(@Nullable final String username) { + ConnectionString cs = new ConnectionString(getenv("OIDC_ATLAS_URI_MULTI")); + // set username + String userPart = username == null ? "" : username + "@"; + return "mongodb+srv://" + userPart + cs.getHosts().get(0) + "/?authMechanism=MONGODB-OIDC"; + } + private static String getAwsOidcUri() { return getOidcUri() + "&authMechanismProperties=PROVIDER_NAME:aws"; } + @NotNull + private static String oidcTokenDirectory() { + return getenv("OIDC_TOKEN_DIR"); + } + + private static String getAwsTokenFilePath() { + return getenv(OidcAuthenticator.AWS_WEB_IDENTITY_TOKEN_FILE); + } + protected MongoClient createMongoClient(final MongoClientSettings settings) { return MongoClients.create(settings); } @@ -147,7 +175,7 @@ public void test2p1ValidCallbackInputs() { TestCallback onRequest = createCallback(); // #. Verify that the request callback was called with the appropriate // inputs, including the timeout parameter if possible. - OidcRequestCallback onRequest2 = (context) -> { + OidcCallback onRequest2 = (context) -> { assertEquals(expectedSeconds, context.getTimeout()); return onRequest.onRequest(context); }; @@ -162,7 +190,7 @@ public void test2p1ValidCallbackInputs() { @Test public void test2p2RequestCallbackReturnsNull() { //noinspection ConstantConditions - OidcRequestCallback onRequest = (context) -> null; + OidcCallback onRequest = (context) -> null; MongoClientSettings settings = this.createSettings(getOidcUri(), onRequest, null); performFind(settings, MongoConfigurationException.class, "Result of callback must not be null"); } @@ -171,9 +199,9 @@ public void test2p2RequestCallbackReturnsNull() { public void test2p3CallbackReturnsMissingData() { // #. Create a client with a request callback that returns data not // conforming to the OIDCRequestTokenResult with missing field(s). - OidcRequestCallback onRequest = (context) -> { + OidcCallback onRequest = (context) -> { //noinspection ConstantConditions - return new RequestCallbackResult(null); + return new OidcCallbackResult(null, Duration.ZERO); }; // we ensure that the error is propagated MongoClientSettings clientSettings = createSettings(getOidcUri(), onRequest, null); @@ -205,8 +233,8 @@ public void test2p4InvalidClientConfigurationWithCallback() { public void test3p1AuthFailsWithCachedToken() throws ExecutionException, InterruptedException, NoSuchFieldException, IllegalAccessException { TestCallback onRequestWrapped = createCallback(); CompletableFuture poisonToken = new CompletableFuture<>(); - OidcRequestCallback onRequest = (context) -> { - RequestCallbackResult result = onRequestWrapped.onRequest(context); + OidcCallback onRequest = (context) -> { + OidcCallbackResult result = onRequestWrapped.onRequest(context); String accessToken = result.getAccessToken(); if (!poisonToken.isDone()) { poisonToken.complete(accessToken); @@ -240,7 +268,7 @@ public void test3p1AuthFailsWithCachedToken() throws ExecutionException, Interru @Test public void test3p2AuthFailsWithoutCachedToken() { MongoClientSettings clientSettings = createSettings(getOidcUri(), - (x) -> new RequestCallbackResult("invalid_token"), null); + (x) -> new OidcCallbackResult("invalid_token", Duration.ZERO), null); try (MongoClient mongoClient = createMongoClient(clientSettings)) { try { performFind(mongoClient); @@ -252,7 +280,6 @@ public void test3p2AuthFailsWithoutCachedToken() { } } - @Test public void test4p1Reauthentication() { TestCallback onRequest = createCallback(); @@ -265,19 +292,328 @@ public void test4p1Reauthentication() { assertEquals(2, onRequest.invocations.get()); } + // Tests for human authentication ("testh", to preserve ordering) + + @Test + public void testh1p1SinglePrincipalImplicitUsername() { + // #. Create default OIDC client with authMechanism=MONGODB-OIDC. + String oidcUri = getOidcUri(); + TestCallback callback = createHumanCallback(); + MongoClientSettings clientSettings = createHumanSettings(oidcUri, callback, null); + // #. Perform a find operation that succeeds + performFind(clientSettings); + assertEquals(1, callback.invocations.get()); + } + + @Test + public void testh1p2SinglePrincipalExplicitUsername() { + // #. Create a client with MONGODB_URI_SINGLE, a username of test_user1, + // authMechanism=MONGODB-OIDC, and the OIDC human callback. + String oidcUri = getOidcUri("test_user1"); + TestCallback callback = createHumanCallback(); + MongoClientSettings clientSettings = createHumanSettings(oidcUri, callback, null); + // #. Perform a find operation that succeeds + performFind(clientSettings); + } + + @Test + public void testh1p3MultiplePrincipalUser1() { + // #. Create a client with MONGODB_URI_MULTI, a username of test_user1, + // authMechanism=MONGODB-OIDC, and the OIDC human callback. + String oidcUri = getOidcUriMulti("test_user1"); + TestCallback callback = createHumanCallback(); + MongoClientSettings clientSettings = createHumanSettings(oidcUri, callback, null); + // #. Perform a find operation that succeeds + performFind(clientSettings); + } + + @Test + public void testh1p4MultiplePrincipalUser2() { + //- Create a human callback that reads in the generated ``test_user2`` token file. + //- Create a client with ``MONGODB_URI_MULTI``, a username of ``test_user2``, + // ``authMechanism=MONGODB-OIDC``, and the OIDC human callback. + String oidcUri = getOidcUriMulti("test_user2"); + TestCallback callback = createHumanCallback() + .setPathSupplier(() -> tokenQueue("test_user2").remove()); + MongoClientSettings clientSettings = createHumanSettings(oidcUri, callback, null); + // #. Perform a find operation that succeeds + performFind(clientSettings); + } + + @Test + public void testh1p5MultiplePrincipalNoUser() { + //- Create a client with ``MONGODB_URI_MULTI``, no username, + // ``authMechanism=MONGODB-OIDC``, and the OIDC human callback. + String oidcUri = getOidcUriMulti(null); + TestCallback callback = createHumanCallback(); + MongoClientSettings clientSettings = createHumanSettings(oidcUri, callback, null); + // #. Perform a find operation that succeeds + performFind(clientSettings, MongoCommandException.class, "Authentication failed"); + } + + @Test + public void testh1p6AllowedHostsBlocked() { + //- Create a default OIDC client, with an ``ALLOWED_HOSTS`` that is an empty list. + //- Assert that a ``find`` operation fails with a client-side error. + MongoClientSettings settings1 = createSettings( + getOidcUri(), + createHumanCallback(), null, OIDC_HUMAN_CALLBACK_KEY, Collections.emptyList()); + performFind(settings1, MongoSecurityException.class, "not permitted by ALLOWED_HOSTS"); + + //- Create a client that uses the URL + // ``mongodb://localhost/?authMechanism=MONGODB-OIDC&ignored=example.com``, a + // human callback, and an ``ALLOWED_HOSTS`` that contains ``["example.com"]``. + //- Assert that a ``find`` operation fails with a client-side error. + MongoClientSettings settings2 = createSettings( + getOidcUri() + "&ignored=example.com", + createHumanCallback(), null, OIDC_HUMAN_CALLBACK_KEY, Arrays.asList("example.com")); + performFind(settings2, MongoSecurityException.class, "not permitted by ALLOWED_HOSTS"); + } + + // Not a prose test + @Test + public void testAllowedHostsDisallowedInConnectionString() { + String string = "mongodb://localhost/?authMechanism=MONGODB-OIDC&authMechanismProperties=ALLOWED_HOSTS:localhost"; + assertCause(IllegalArgumentException.class, + "connection string contains disallowed mechanism properties", + () -> new ConnectionString(string)); + } + + @Test + public void testh2p1ValidCallbackInputs() { + TestCallback onRequest = createHumanCallback(); + OidcCallback onRequest2 = (context) -> { + assertTrue(context.getIdpInfo().getClientId().startsWith("0oad")); + assertTrue(context.getIdpInfo().getIssuer().endsWith("mock-identity-config-oidc")); + assertEquals(Arrays.asList("fizz", "buzz"), context.getIdpInfo().getRequestScopes()); + assertEquals(Duration.ofMinutes(5), context.getTimeout()); + return onRequest.onRequest(context); + }; + MongoClientSettings clientSettings = createHumanSettings(getOidcUri(), onRequest2, null); + try (MongoClient mongoClient = createMongoClient(clientSettings)) { + performFind(mongoClient); + // Ensure that callback was called + assertEquals(1, onRequest.getInvocations()); + } + } + + @Test + public void testh2p2HumanCallbackReturnsMissingData() { + //noinspection ConstantConditions + OidcCallback onRequestNull = (context) -> null; + performFind(createHumanSettings(getOidcUri(), onRequestNull, null), + MongoConfigurationException.class, + "Result of callback must not be null"); + + //noinspection ConstantConditions + OidcCallback onRequest = (context) -> new OidcCallbackResult(null, Duration.ZERO); + performFind(createHumanSettings(getOidcUri(), onRequest, null), + IllegalArgumentException.class, + "accessToken can not be null"); + + // additionally, check validation for refresh in machine workflow: + OidcCallback onRequestMachineRefresh = (context) -> new OidcCallbackResult("access", Duration.ZERO, "exists"); + performFind(createSettings(getOidcUri(), onRequestMachineRefresh, null), + MongoConfigurationException.class, + "Refresh token must only be provided in human workflow"); + } + + @Test + public void testh3p1UsesSpecAuthIfCachedToken() { + failCommandAndCloseConnection("find", 1); + MongoClientSettings settings = createHumanSettings(getOidcUri(), createHumanCallback(), null); + + try (MongoClient mongoClient = createMongoClient(settings)) { + assertCause(MongoSocketException.class, + "Prematurely reached end of stream", + () -> performFind(mongoClient)); + failCommand(20, 99, "saslStart"); + + performFind(mongoClient); + } + } + + @Test + public void testh3p2NoSpecAuthIfNoCachedToken() { + failCommand(20, 99, "saslStart"); + TestListener listener = new TestListener(); + TestCommandListener commandListener = new TestCommandListener(listener); + performFind(createHumanSettings(getOidcUri(), createHumanCallback(), commandListener), + MongoCommandException.class, + "Command failed with error 20"); + assertEquals(Arrays.asList( + "isMaster started", + "isMaster succeeded", + "saslStart started", + "saslStart failed" + ), listener.getEventStrings()); + listener.clear(); + } + + @Test + public void testh4p1Succeeds() { + TestListener listener = new TestListener(); + TestCommandListener commandListener = new TestCommandListener(listener); + TestCallback callback = createHumanCallback() + .setEventListener(listener); + MongoClientSettings settings = createHumanSettings(getOidcUri(), callback, commandListener); + try (MongoClient mongoClient = createMongoClient(settings)) { + performFind(mongoClient); + listener.clear(); + assertEquals(1, callback.getInvocations()); + + failCommand(391, 1, "find"); + performFind(mongoClient); + assertEquals(Arrays.asList( + // first find fails: + "find started", + "find failed", + "onRequest invoked (Refresh Token: present - IdpInfo: present)", + "read access token: test_user1", + "saslStart started", + "saslStart succeeded", + // second find succeeds: + "find started", + "find succeeded" + ), listener.getEventStrings()); + assertEquals(2, callback.getInvocations()); + } + } + + @Test + public void testh4p2SucceedsNoRefresh() { + TestListener listener = new TestListener(); + TestCommandListener commandListener = new TestCommandListener(listener); + TestCallback callback = createHumanCallback().setEventListener(listener); + MongoClientSettings settings = createHumanSettings(getOidcUri(), callback, commandListener); + try (MongoClient mongoClient = createMongoClient(settings)) { + + performFind(mongoClient); + listener.clear(); + assertEquals(1, callback.getInvocations()); + + failCommand(391, 1, "find"); + performFind(mongoClient); + } + } + + + // TODO-OIDC awaiting spec updates, add 4.3 and 4.4 + + // Not a prose test + @Test + public void testErrorClearsCache() { + // #. Create a new client with a valid request callback that + // gives credentials that expire within 5 minutes and + // a refresh callback that gives invalid credentials. + TestListener listener = new TestListener(); + ConcurrentLinkedQueue tokens = tokenQueue( + "test_user1", + "test_user1_expires", + "test_user1_expires", + "test_user1_1"); + TestCallback onRequest = createHumanCallback() + .setRefreshToken("refresh") + .setPathSupplier(() -> tokens.remove()) + .setEventListener(listener); + + TestCommandListener commandListener = new TestCommandListener(listener); + + MongoClientSettings clientSettings = createHumanSettings(getOidcUri(), onRequest, commandListener); + try (MongoClient mongoClient = createMongoClient(clientSettings)) { + // #. Ensure that a find operation adds a new entry to the cache. + performFind(mongoClient); + assertEquals(Arrays.asList( + "isMaster started", + "isMaster succeeded", + // no speculative auth. Send principal request: + "saslStart started", + "saslStart succeeded", + "onRequest invoked (Refresh Token: none - IdpInfo: present)", + "read access token: test_user1", + // the refresh token from the callback is cached here + // send jwt: + "saslContinue started", + "saslContinue succeeded", + "find started", + "find succeeded" + ), listener.getEventStrings()); + listener.clear(); + + // #. Ensure that a subsequent find operation results in a 391 error. + failCommand(391, 1, "find"); + // ensure that the operation entirely fails, after attempting both potential fallback callbacks + assertThrows(MongoSecurityException.class, () -> performFind(mongoClient)); + assertEquals(Arrays.asList( + "find started", + "find failed", // reauth 391; current access token is invalid + // fall back to refresh token, from prior find + "onRequest invoked (Refresh Token: present - IdpInfo: present)", + "read access token: test_user1_expires", + "saslStart started", + "saslStart failed", // it is expired, fails immediately + // fall back to principal request, and non-refresh callback: + "saslStart started", + "saslStart succeeded", + "onRequest invoked (Refresh Token: none - IdpInfo: present)", + "read access token: test_user1_expires", + "saslContinue started", + "saslContinue failed" // also fails due to 391 + ), listener.getEventStrings()); + listener.clear(); + + // #. Ensure that the cache value cleared. + failCommand(391, 1, "find"); + performFind(mongoClient); + assertEquals(Arrays.asList( + "find started", + "find failed", + // falling back to principal request, onRequest callback. + // this implies that the cache has been cleared during the + // preceding find operation. + "saslStart started", + "saslStart succeeded", + "onRequest invoked (Refresh Token: none - IdpInfo: present)", + "read access token: test_user1_1", + "saslContinue started", + "saslContinue succeeded", + // auth has finished + "find started", + "find succeeded" + ), listener.getEventStrings()); + listener.clear(); + } + } + public MongoClientSettings createSettings( final String connectionString, - @Nullable final OidcRequestCallback onRequest) { + @Nullable final OidcCallback onRequest) { return createSettings(connectionString, onRequest, null); } private MongoClientSettings createSettings( final String connectionString, - @Nullable final OidcRequestCallback onRequest, + @Nullable final OidcCallback callback, @Nullable final CommandListener commandListener) { + return createSettings(connectionString, callback, commandListener, OIDC_CALLBACK_KEY); + } + + private MongoClientSettings createHumanSettings( + final String connectionString, + @Nullable final OidcCallback callback, + @Nullable final CommandListener commandListener) { + return createSettings(connectionString, callback, commandListener, OIDC_HUMAN_CALLBACK_KEY); + } + + @NotNull + private MongoClientSettings createSettings( + final String connectionString, + @Nullable final OidcCallback onRequest, + @Nullable final CommandListener commandListener, + final String oidcCallbackKey) { ConnectionString cs = new ConnectionString(connectionString); MongoCredential credential = cs.getCredential() - .withMechanismProperty(OIDC_CALLBACK_KEY, onRequest); + .withMechanismProperty(oidcCallbackKey, onRequest); MongoClientSettings.Builder builder = MongoClientSettings.builder() .applicationName(appName) .applyConnectionString(cs) @@ -289,6 +625,26 @@ private MongoClientSettings createSettings( return builder.build(); } + private MongoClientSettings createSettings( + final String connectionString, + @Nullable final OidcCallback onRequest, + @Nullable final CommandListener commandListener, + final String oidcCallbackKey, + @Nullable final List allowedHosts) { + ConnectionString cs = new ConnectionString(connectionString); + MongoCredential credential = cs.getCredential() + .withMechanismProperty(oidcCallbackKey, onRequest) + .withMechanismProperty(ALLOWED_HOSTS_KEY, allowedHosts); + MongoClientSettings.Builder builder = MongoClientSettings.builder() + .applicationName(appName) + .applyConnectionString(cs) + .credential(credential); + if (commandListener != null) { + builder.addCommandListener(commandListener); + } + return builder.build(); + } + private void performFind(final MongoClientSettings settings) { try (MongoClient mongoClient = createMongoClient(settings)) { performFind(mongoClient); @@ -333,7 +689,8 @@ private static void assertCause( } protected void delayNextFind() { - try (MongoClient client = createMongoClient(createSettings(getAwsOidcUri(), null, null))) { + try (MongoClient client = createMongoClient(createSettings( + getAwsOidcUri(), null, null))) { BsonDocument failPointDocument = new BsonDocument("configureFailPoint", new BsonString("failCommand")) .append("mode", new BsonDocument("times", new BsonInt32(1))) .append("data", new BsonDocument() @@ -359,11 +716,27 @@ protected void failCommand(final int code, final int times, final String... comm } } - public static class TestCallback implements OidcRequestCallback { + private void failCommandAndCloseConnection(final String command, final int times) { + try (MongoClient mongoClient = createMongoClient(createSettings( + getAwsOidcUri(), null, null))) { + BsonDocument failPointDocument = new BsonDocument("configureFailPoint", new BsonString("failCommand")) + .append("mode", new BsonDocument("times", new BsonInt32(times))) + .append("data", new BsonDocument() + .append("appName", new BsonString(appName)) + .append("closeConnection", new BsonBoolean(true)) + .append("failCommands", new BsonArray(Arrays.asList(new BsonString(command)))) + ); + mongoClient.getDatabase("admin").runCommand(failPointDocument); + } + } + + public static class TestCallback implements OidcCallback { private final AtomicInteger invocations = new AtomicInteger(); @Nullable private final Integer delayInMilliseconds; @Nullable + private final String refreshToken; + @Nullable private final AtomicInteger concurrentTracker; @Nullable private final TestListener testListener; @@ -371,14 +744,16 @@ public static class TestCallback implements OidcRequestCallback { private final Supplier pathSupplier; public TestCallback() { - this(null, new AtomicInteger(), null, null); + this(null, null, new AtomicInteger(), null, null); } public TestCallback( + @Nullable final String refreshToken, @Nullable final Integer delayInMilliseconds, @Nullable final AtomicInteger concurrentTracker, @Nullable final TestListener testListener, @Nullable final Supplier pathSupplier) { + this.refreshToken = refreshToken; this.delayInMilliseconds = delayInMilliseconds; this.concurrentTracker = concurrentTracker; this.testListener = testListener; @@ -390,15 +765,18 @@ public int getInvocations() { } @Override - public RequestCallbackResult onRequest(final OidcRequestContext context) { + public OidcCallbackResult onRequest(final OidcCallbackContext context) { if (testListener != null) { - testListener.add("onRequest invoked"); + testListener.add("onRequest invoked (" + + "Refresh Token: " + (context.getRefreshToken() == null ? "none" : "present") + + " - IdpInfo: " + (context.getIdpInfo() == null ? "none" : "present") + + ")"); } return callback(); } @NotNull - private RequestCallbackResult callback() { + private OidcCallbackResult callback() { if (concurrentTracker != null) { if (concurrentTracker.get() > 0) { throw new RuntimeException("Callbacks should not be invoked by multiple threads."); @@ -408,7 +786,7 @@ private RequestCallbackResult callback() { try { invocations.incrementAndGet(); Path path = Paths.get(pathSupplier == null - ? getenv(OidcAuthenticator.AWS_WEB_IDENTITY_TOKEN_FILE) + ? getAwsTokenFilePath() : pathSupplier.get()); String accessToken; try { @@ -420,7 +798,7 @@ private RequestCallbackResult callback() { if (testListener != null) { testListener.add("read access token: " + path.getFileName()); } - return new RequestCallbackResult(accessToken); + return new OidcCallbackResult(accessToken, Duration.ZERO, refreshToken); } finally { if (concurrentTracker != null) { concurrentTracker.decrementAndGet(); @@ -436,6 +814,7 @@ private void simulateDelay() throws InterruptedException { public TestCallback setDelayMs(final int milliseconds) { return new TestCallback( + this.refreshToken, milliseconds, this.concurrentTracker, this.testListener, @@ -444,6 +823,7 @@ public TestCallback setDelayMs(final int milliseconds) { public TestCallback setConcurrentTracker(final AtomicInteger c) { return new TestCallback( + this.refreshToken, this.delayInMilliseconds, c, this.testListener, @@ -452,6 +832,7 @@ public TestCallback setConcurrentTracker(final AtomicInteger c) { public TestCallback setEventListener(final TestListener testListener) { return new TestCallback( + this.refreshToken, this.delayInMilliseconds, this.concurrentTracker, testListener, @@ -460,14 +841,38 @@ public TestCallback setEventListener(final TestListener testListener) { public TestCallback setPathSupplier(final Supplier pathSupplier) { return new TestCallback( + this.refreshToken, this.delayInMilliseconds, this.concurrentTracker, this.testListener, pathSupplier); } + public TestCallback setRefreshToken(final String token) { + return new TestCallback( + token, + this.delayInMilliseconds, + this.concurrentTracker, + this.testListener, + this.pathSupplier); + } + } + + @NotNull + private ConcurrentLinkedQueue tokenQueue(final String... queue) { + String tokenPath = oidcTokenDirectory(); + return java.util.stream.Stream + .of(queue) + .map(v -> tokenPath + v) + .collect(Collectors.toCollection(ConcurrentLinkedQueue::new)); } public TestCallback createCallback() { return new TestCallback(); } + + public TestCallback createHumanCallback() { + return new TestCallback() + .setPathSupplier(() -> oidcTokenDirectory() + "test_user1") + .setRefreshToken("refreshToken"); + } } From 106ee4dcab9bc6202238aba865fbb385c3389a22 Mon Sep 17 00:00:00 2001 From: Maxim Katcharov Date: Mon, 29 Apr 2024 15:00:08 -0600 Subject: [PATCH 5/6] OIDC Add remaining environments (azure, gcp), evergreen testing, API naming updates (#1371) JAVA-5353 JAVA-5395 JAVA-4834 JAVA-4932 --------- Co-authored-by: Valentin Kovalenko --- .evergreen/.evg.yml | 152 ++++- .evergreen/run-mongodb-oidc-test.sh | 40 ++ .../main/com/mongodb/ConnectionString.java | 48 +- .../src/main/com/mongodb/MongoCredential.java | 84 ++- .../src/main/com/mongodb/internal/Locks.java | 20 +- .../authentication/AzureCredentialHelper.java | 71 ++- .../authentication/CredentialInfo.java | 44 ++ .../authentication/GcpCredentialHelper.java | 13 + .../connection/OidcAuthenticator.java | 179 ++++-- .../auth/legacy/connection-string.json | 153 ++++- .../auth/mongodb-oidc-no-retry.json | 41 +- .../ConnectionStringSpecification.groovy | 2 +- .../com/mongodb/ConnectionStringUnitTest.java | 47 ++ .../com/mongodb/client/unified/Entities.java | 70 +- .../unified/RunOnRequirementsMatcher.java | 5 +- .../OidcAuthenticationProseTests.java | 600 ++++++++++++------ 16 files changed, 1186 insertions(+), 383 deletions(-) create mode 100755 .evergreen/run-mongodb-oidc-test.sh create mode 100644 driver-core/src/main/com/mongodb/internal/authentication/CredentialInfo.java diff --git a/.evergreen/.evg.yml b/.evergreen/.evg.yml index d35c01fd89f..886282b77c4 100644 --- a/.evergreen/.evg.yml +++ b/.evergreen/.evg.yml @@ -12,9 +12,8 @@ stepback: true # Actual testing tasks are marked with `type: test` command_type: system -# Protect ourself against rogue test case, or curl gone wild, that runs forever -# 12 minutes is the longest we'll ever run -exec_timeout_secs: 3600 # 12 minutes is the longest we'll ever run +# Protect ourselves against rogue test case, or curl gone wild, that runs forever +exec_timeout_secs: 3600 # What to do when evergreen hits the timeout (`post:` tasks are run automatically) timeout: @@ -968,6 +967,60 @@ tasks: - func: "run load-balancer" - func: "run load-balancer tests" + - name: "oidc-auth-test" + commands: + - command: subprocess.exec + type: test + params: + working_dir: "src" + binary: bash + include_expansions_in_env: ["DRIVERS_TOOLS", "AWS_ACCESS_KEY_ID", "AWS_SECRET_ACCESS_KEY", "AWS_SESSION_TOKEN"] + env: + OIDC_ENV: "test" + args: + - .evergreen/run-mongodb-oidc-test.sh + + - name: "oidc-auth-test-azure" + commands: + - command: shell.exec + params: + shell: bash + env: + JAVA_HOME: ${JAVA_HOME} + script: |- + set -o errexit + ${PREPARE_SHELL} + cd src + git add . + git commit -m "add files" + # uncompressed tar used to allow appending .git folder + export AZUREOIDC_DRIVERS_TAR_FILE=/tmp/mongo-java-driver.tar + git archive -o $AZUREOIDC_DRIVERS_TAR_FILE HEAD + tar -rf $AZUREOIDC_DRIVERS_TAR_FILE .git + export AZUREOIDC_TEST_CMD="OIDC_ENV=azure ./.evergreen/run-mongodb-oidc-test.sh" + bash $DRIVERS_TOOLS/.evergreen/auth_oidc/azure/run-driver-test.sh + + - name: "oidc-auth-test-gcp" + commands: + - command: shell.exec + params: + shell: bash + script: |- + set -o errexit + ${PREPARE_SHELL} + cd src + git add . + git commit -m "add files" + # uncompressed tar used to allow appending .git folder + export GCPOIDC_DRIVERS_TAR_FILE=/tmp/mongo-java-driver.tar + git archive -o $GCPOIDC_DRIVERS_TAR_FILE HEAD + tar -rf $GCPOIDC_DRIVERS_TAR_FILE .git + # Define the command to run on the VM. + # Ensure that we source the environment file created for us, set up any other variables we need, + # and then run our test suite on the vm. + export GCPOIDC_TEST_CMD="OIDC_ENV=gcp ./.evergreen/run-mongodb-oidc-test.sh" + bash $DRIVERS_TOOLS/.evergreen/auth_oidc/gcp/run-driver-test.sh + - name: serverless-test commands: - func: "run serverless" @@ -2065,6 +2118,78 @@ task_groups: tasks: - test-aws-lambda-deployed + - name: testoidc_task_group + setup_group: + - func: fetch source + - func: prepare resources + - func: fix absolute paths + - command: ec2.assume_role + params: + role_arn: ${aws_test_secrets_role} + - command: subprocess.exec + params: + binary: bash + include_expansions_in_env: ["AWS_ACCESS_KEY_ID", "AWS_SECRET_ACCESS_KEY", "AWS_SESSION_TOKEN"] + args: + - ${DRIVERS_TOOLS}/.evergreen/auth_oidc/setup.sh + teardown_task: + - command: subprocess.exec + params: + binary: bash + args: + - ${DRIVERS_TOOLS}/.evergreen/auth_oidc/teardown.sh + setup_group_can_fail_task: true + setup_group_timeout_secs: 1800 + tasks: + - oidc-auth-test + + - name: testazureoidc_task_group + setup_group: + - func: fetch source + - func: prepare resources + - func: fix absolute paths + - command: subprocess.exec + params: + binary: bash + env: + AZUREOIDC_VMNAME_PREFIX: "JAVA_DRIVER" + args: + - ${DRIVERS_TOOLS}/.evergreen/auth_oidc/azure/create-and-setup-vm.sh + teardown_task: + - command: subprocess.exec + params: + binary: bash + args: + - ${DRIVERS_TOOLS}/.evergreen/auth_oidc/azure/delete-vm.sh + setup_group_can_fail_task: true + setup_group_timeout_secs: 1800 + tasks: + - oidc-auth-test-azure + + - name: testgcpoidc_task_group + setup_group: + - func: fetch source + - func: prepare resources + - func: fix absolute paths + - command: subprocess.exec + params: + binary: bash + env: + GCPOIDC_VMNAME_PREFIX: "JAVA_DRIVER" + GCPKMS_MACHINETYPE: "e2-medium" # comparable elapsed time to Azure; default was starved, caused timeouts + args: + - ${DRIVERS_TOOLS}/.evergreen/auth_oidc/gcp/setup.sh + teardown_task: + - command: subprocess.exec + params: + binary: bash + args: + - ${DRIVERS_TOOLS}/.evergreen/auth_oidc/gcp/teardown.sh + setup_group_can_fail_task: true + setup_group_timeout_secs: 1800 + tasks: + - oidc-auth-test-gcp + buildvariants: # Test packaging and other release related routines @@ -2216,6 +2341,27 @@ buildvariants: tasks: - name: "test_atlas_task_group_search_indexes" +- name: "oidc-auth-test" + display_name: "OIDC Auth" + run_on: ubuntu2204-small + tasks: + - name: testoidc_task_group + batchtime: 20160 # 14 days + +- name: testazureoidc-variant + display_name: "OIDC Auth Azure" + run_on: ubuntu2204-small + tasks: + - name: testazureoidc_task_group + batchtime: 20160 # 14 days + +- name: testgcpoidc-variant + display_name: "OIDC Auth GCP" + run_on: ubuntu2204-small + tasks: + - name: testgcpoidc_task_group + batchtime: 20160 # 14 days + - matrix_name: "aws-auth-test" matrix_spec: { ssl: "nossl", jdk: ["jdk8", "jdk17", "jdk21"], version: ["4.4", "5.0", "6.0", "7.0", "latest"], os: "ubuntu", aws-credential-provider: "*" } diff --git a/.evergreen/run-mongodb-oidc-test.sh b/.evergreen/run-mongodb-oidc-test.sh new file mode 100755 index 00000000000..1f5c1b310cc --- /dev/null +++ b/.evergreen/run-mongodb-oidc-test.sh @@ -0,0 +1,40 @@ +#!/bin/bash + +set +x # Disable debug trace +set -eu + +echo "Running MONGODB-OIDC authentication tests" +echo "OIDC_ENV $OIDC_ENV" + +if [ $OIDC_ENV == "test" ]; then + if [ -z "$DRIVERS_TOOLS" ]; then + echo "Must specify DRIVERS_TOOLS" + exit 1 + fi + source ${DRIVERS_TOOLS}/.evergreen/auth_oidc/secrets-export.sh + # java will not need to be installed, but we need to config + RELATIVE_DIR_PATH="$(dirname "${BASH_SOURCE:-$0}")" + source "${RELATIVE_DIR_PATH}/javaConfig.bash" +elif [ $OIDC_ENV == "azure" ]; then + source ./env.sh +elif [ $OIDC_ENV == "gcp" ]; then + source ./secrets-export.sh +else + echo "Unrecognized OIDC_ENV $OIDC_ENV" + exit 1 +fi + + +if ! which java ; then + echo "Installing java..." + sudo apt install openjdk-17-jdk -y + echo "Installed java." +fi + +which java +export OIDC_TESTS_ENABLED=true + +./gradlew -Dorg.mongodb.test.uri="$MONGODB_URI" \ + --stacktrace --debug --info --no-build-cache driver-core:cleanTest \ + driver-sync:test --tests OidcAuthenticationProseTests --tests UnifiedAuthTest \ + driver-reactive-streams:test --tests OidcAuthenticationAsyncProseTests \ diff --git a/driver-core/src/main/com/mongodb/ConnectionString.java b/driver-core/src/main/com/mongodb/ConnectionString.java index 8bb802e9e70..ae795a65bba 100644 --- a/driver-core/src/main/com/mongodb/ConnectionString.java +++ b/driver-core/src/main/com/mongodb/ConnectionString.java @@ -38,6 +38,7 @@ import java.net.URLDecoder; import java.nio.charset.StandardCharsets; import java.util.ArrayList; +import java.util.Arrays; import java.util.Collections; import java.util.HashMap; import java.util.HashSet; @@ -229,9 +230,9 @@ * *

Authentication configuration:

*
    - *
  • {@code authMechanism=MONGO-CR|GSSAPI|PLAIN|MONGODB-X509}: The authentication mechanism to use if a credential was supplied. + *
  • {@code authMechanism=MONGO-CR|GSSAPI|PLAIN|MONGODB-X509|MONGODB-OIDC}: The authentication mechanism to use if a credential was supplied. * The default is unspecified, in which case the client will pick the most secure mechanism available based on the sever version. For the - * GSSAPI and MONGODB-X509 mechanisms, no password is accepted, only the username. + * GSSAPI, MONGODB-X509, and MONGODB-OIDC mechanisms, no password is accepted, only the username. *
  • *
  • {@code authSource=string}: The source of the authentication credentials. This is typically the database that * the credentials have been created. The value defaults to the database specified in the path portion of the connection string. @@ -239,7 +240,9 @@ * mechanism (the default). *
  • *
  • {@code authMechanismProperties=PROPERTY_NAME:PROPERTY_VALUE,PROPERTY_NAME2:PROPERTY_VALUE2}: This option allows authentication - * mechanism properties to be set on the connection string. + * mechanism properties to be set on the connection string. Property values must be percent-encoded individually, when + * separator or escape characters are used (including {@code ,} (comma), {@code =}, {@code +}, {@code &}, and {@code %}). The + * entire substring following the {@code =} should not itself be encoded. *
  • *
  • {@code gssapiServiceName=string}: This option only applies to the GSSAPI mechanism and is used to alter the service name. * Deprecated, please use {@code authMechanismProperties=SERVICE_NAME:string} instead. @@ -916,13 +919,16 @@ private MongoCredential createCredentials(final Map> option if (credential != null && authMechanismProperties != null) { for (String part : authMechanismProperties.split(",")) { - String[] mechanismPropertyKeyValue = part.split(":"); + String[] mechanismPropertyKeyValue = part.split(":", 2); if (mechanismPropertyKeyValue.length != 2) { throw new IllegalArgumentException(format("The connection string contains invalid authentication properties. " + "'%s' is not a key value pair", part)); } String key = mechanismPropertyKeyValue[0].trim().toLowerCase(); String value = mechanismPropertyKeyValue[1].trim(); + if (decodeValueOfKeyValuePair(credential.getMechanism())) { + value = urldecode(value); + } if (MECHANISM_KEYS_DISALLOWED_IN_CONNECTION_STRING.contains(key)) { throw new IllegalArgumentException(format("The connection string contains disallowed mechanism properties. " + "'%s' must be set on the credential programmatically.", key)); @@ -938,6 +944,27 @@ private MongoCredential createCredentials(final Map> option return credential; } + private static boolean decodeWholeOptionValue(final boolean isOidc, final String key) { + // The "whole option value" is the entire string following = in an option, + // including separators when the value is a list or list of key-values. + // This is the original parsing behaviour, but implies that users can + // encode separators (much like they might with URL parameters). This + // behaviour implies that users cannot encode "key-value" values that + // contain a comma, because this will (after this "whole value decoding) + // be parsed as a key-value separator, rather than part of a value. + return !(isOidc && key.equals("authmechanismproperties")); + } + + private static boolean decodeValueOfKeyValuePair(@Nullable final String mechanismName) { + // Only authMechanismProperties should be individually decoded, and only + // when the mechanism is OIDC. These will not have been decoded. + return AuthenticationMechanism.MONGODB_OIDC.getMechanismName().equals(mechanismName); + } + + private static boolean isOidc(final List options) { + return options.contains("authMechanism=" + AuthenticationMechanism.MONGODB_OIDC.getMechanismName()); + } + private MongoCredential createMongoCredentialWithMechanism(final AuthenticationMechanism mechanism, final String userName, @Nullable final char[] password, @Nullable final String authSource, @@ -1018,12 +1045,14 @@ private String getLastValue(final Map> optionsMap, final St private Map> parseOptions(final String optionsPart) { Map> optionsMap = new HashMap<>(); - if (optionsPart.length() == 0) { + if (optionsPart.isEmpty()) { return optionsMap; } - for (final String part : optionsPart.split("&|;")) { - if (part.length() == 0) { + List options = Arrays.asList(optionsPart.split("&|;")); + boolean isOidc = isOidc(options); + for (final String part : options) { + if (part.isEmpty()) { continue; } int idx = part.indexOf("="); @@ -1034,7 +1063,10 @@ private Map> parseOptions(final String optionsPart) { if (valueList == null) { valueList = new ArrayList<>(1); } - valueList.add(urldecode(value)); + if (decodeWholeOptionValue(isOidc, key)) { + value = urldecode(value); + } + valueList.add(value); optionsMap.put(key, valueList); } else { throw new IllegalArgumentException(format("The connection string contains an invalid option '%s'. " diff --git a/driver-core/src/main/com/mongodb/MongoCredential.java b/driver-core/src/main/com/mongodb/MongoCredential.java index 295803e55a4..e085ac074f0 100644 --- a/driver-core/src/main/com/mongodb/MongoCredential.java +++ b/driver-core/src/main/com/mongodb/MongoCredential.java @@ -37,6 +37,7 @@ import static com.mongodb.AuthenticationMechanism.SCRAM_SHA_1; import static com.mongodb.AuthenticationMechanism.SCRAM_SHA_256; import static com.mongodb.assertions.Assertions.notNull; +import static com.mongodb.internal.connection.OidcAuthenticator.OidcValidator.validateCreateOidcCredential; import static com.mongodb.internal.connection.OidcAuthenticator.OidcValidator.validateOidcCredentialConstruction; /** @@ -185,7 +186,13 @@ public final class MongoCredential { public static final String AWS_CREDENTIAL_PROVIDER_KEY = "AWS_CREDENTIAL_PROVIDER"; /** - * The provider name. The value must be a string. + * Mechanism property key for specifying the environment for OIDC, which is + * the name of a built-in OIDC application environment integration to use + * to obtain credentials. The value must be either "gcp" or "azure". + * This is an alternative to supplying a callback. + *

    + * The "gcp" and "azure" environments require + * {@link MongoCredential#TOKEN_RESOURCE_KEY} to be specified. *

    * If this is provided, * {@link MongoCredential#OIDC_CALLBACK_KEY} and @@ -193,51 +200,54 @@ public final class MongoCredential { * must not be provided. * * @see #createOidcCredential(String) - * @since 4.10 + * @see MongoCredential#TOKEN_RESOURCE_KEY + * @since 5.1 */ - public static final String PROVIDER_NAME_KEY = "PROVIDER_NAME"; + public static final String ENVIRONMENT_KEY = "ENVIRONMENT"; /** + * Mechanism property key for the OIDC callback. * This callback is invoked when the OIDC-based authenticator requests * a token. The type of the value must be {@link OidcCallback}. * {@link IdpInfo} will not be supplied to the callback, - * and a {@linkplain OidcCallbackResult#getRefreshToken() refresh token} + * and a {@linkplain com.mongodb.MongoCredential.OidcCallbackResult#getRefreshToken() refresh token} * must not be returned by the callback. *

    - * If this is provided, {@link MongoCredential#PROVIDER_NAME_KEY} + * If this is provided, {@link MongoCredential#ENVIRONMENT_KEY} * and {@link MongoCredential#OIDC_HUMAN_CALLBACK_KEY} * must not be provided. * * @see #createOidcCredential(String) - * @since 4.10 + * @since 5.1 */ public static final String OIDC_CALLBACK_KEY = "OIDC_CALLBACK"; /** + * Mechanism property key for the OIDC human callback. * This callback is invoked when the OIDC-based authenticator requests * a token from the identity provider (IDP) using the IDP information * from the MongoDB server. The type of the value must be * {@link OidcCallback}. *

    - * If this is provided, {@link MongoCredential#PROVIDER_NAME_KEY} + * If this is provided, {@link MongoCredential#ENVIRONMENT_KEY} * and {@link MongoCredential#OIDC_CALLBACK_KEY} * must not be provided. * * @see #createOidcCredential(String) - * @since 4.10 + * @since 5.1 */ public static final String OIDC_HUMAN_CALLBACK_KEY = "OIDC_HUMAN_CALLBACK"; /** - * Mechanism key for a list of allowed hostnames or ip-addresses for MongoDB connections. Ports must be excluded. + * Mechanism property key for a list of allowed hostnames or ip-addresses for MongoDB connections. Ports must be excluded. * The hostnames may include a leading "*." wildcard, which allows for matching (potentially nested) subdomains. * When MONGODB-OIDC authentication is attempted against a hostname that does not match any of list of allowed hosts * the driver will raise an error. The type of the value must be {@code List}. * * @see MongoCredential#DEFAULT_ALLOWED_HOSTS * @see #createOidcCredential(String) - * @since 4.10 + * @since 5.1 */ public static final String ALLOWED_HOSTS_KEY = "ALLOWED_HOSTS"; @@ -248,11 +258,21 @@ public final class MongoCredential { * {@code "*.mongodb.net", "*.mongodb-qa.net", "*.mongodb-dev.net", "*.mongodbgov.net", "localhost", "127.0.0.1", "::1"} * * @see #createOidcCredential(String) - * @since 4.10 + * @since 5.1 */ public static final List DEFAULT_ALLOWED_HOSTS = Collections.unmodifiableList(Arrays.asList( "*.mongodb.net", "*.mongodb-qa.net", "*.mongodb-dev.net", "*.mongodbgov.net", "localhost", "127.0.0.1", "::1")); + /** + * Mechanism property key for specifying he URI of the target resource (sometimes called the audience), + * used in some OIDC environments. + * + * @see MongoCredential#ENVIRONMENT_KEY + * @see #createOidcCredential(String) + * @since 5.1 + */ + public static final String TOKEN_RESOURCE_KEY = "TOKEN_RESOURCE"; + /** * Creates a MongoCredential instance with an unspecified mechanism. The client will negotiate the best mechanism based on the * version of the server that the client is authenticating to. @@ -406,9 +426,10 @@ public static MongoCredential createAwsCredential(@Nullable final String userNam * * @param userName the user name, which may be null. This is the OIDC principal name. * @return the credential - * @since 4.10 + * @since 5.1 * @see #withMechanismProperty(String, Object) - * @see #PROVIDER_NAME_KEY + * @see #ENVIRONMENT_KEY + * @see #TOKEN_RESOURCE_KEY * @see #OIDC_CALLBACK_KEY * @see #OIDC_HUMAN_CALLBACK_KEY * @see #ALLOWED_HOSTS_KEY @@ -463,6 +484,7 @@ public MongoCredential withMechanism(final AuthenticationMechanism mechanism) { if (mechanism == MONGODB_OIDC) { validateOidcCredentialConstruction(source, mechanismProperties); + validateCreateOidcCredential(password); } if (userName == null && !Arrays.asList(MONGODB_X509, MONGODB_AWS, MONGODB_OIDC).contains(mechanism)) { @@ -641,14 +663,16 @@ public String toString() { /** * The context for the {@link OidcCallback#onRequest(OidcCallbackContext) OIDC request callback}. + * + * @since 5.1 */ @Evolving public interface OidcCallbackContext { /** - * @return The OIDC Identity Provider's configuration that can be used to acquire an Access Token. + * @return Convenience method to obtain the {@linkplain MongoCredential#getUserName() username}. */ @Nullable - IdpInfo getIdpInfo(); + String getUserName(); /** * @return The timeout that this callback must complete within. @@ -661,7 +685,17 @@ public interface OidcCallbackContext { int getVersion(); /** - * @return The OIDC Refresh token supplied by a prior callback invocation. + * @return The OIDC Identity Provider's configuration that can be used + * to acquire an Access Token, or null if not using a + * {@linkplain MongoCredential#OIDC_HUMAN_CALLBACK_KEY human callback.} + */ + @Nullable + IdpInfo getIdpInfo(); + + /** + * @return The OIDC Refresh token supplied by a prior callback invocation, + * or null if no token was supplied, or if not using a + * {@linkplain MongoCredential#OIDC_HUMAN_CALLBACK_KEY human callback.} */ @Nullable String getRefreshToken(); @@ -673,6 +707,8 @@ public interface OidcCallbackContext { *

    * It does not have to be thread-safe, unless it is provided to multiple * MongoClients. + * + * @since 5.1 */ public interface OidcCallback { /** @@ -684,6 +720,8 @@ public interface OidcCallback { /** * The OIDC Identity Provider's configuration that can be used to acquire an Access Token. + * + * @since 5.1 */ @Evolving public interface IdpInfo { @@ -697,6 +735,7 @@ public interface IdpInfo { /** * @return Unique client ID for this OIDC client. */ + @Nullable String getClientId(); /** @@ -706,7 +745,9 @@ public interface IdpInfo { } /** - * The response produced by an OIDC Identity Provider. + * The OIDC credential information. + * + * @since 5.1 */ public static final class OidcCallbackResult { @@ -717,6 +758,15 @@ public static final class OidcCallbackResult { @Nullable private final String refreshToken; + + /** + * An access token that does not expire. + * @param accessToken The OIDC access token. + */ + public OidcCallbackResult(final String accessToken) { + this(accessToken, Duration.ZERO, null); + } + /** * @param accessToken The OIDC access token. * @param expiresIn Time until the access token expires. diff --git a/driver-core/src/main/com/mongodb/internal/Locks.java b/driver-core/src/main/com/mongodb/internal/Locks.java index 2a169f45c52..984de156f27 100644 --- a/driver-core/src/main/com/mongodb/internal/Locks.java +++ b/driver-core/src/main/com/mongodb/internal/Locks.java @@ -17,8 +17,6 @@ package com.mongodb.internal; import com.mongodb.MongoInterruptedException; -import com.mongodb.internal.async.AsyncRunnable; -import com.mongodb.internal.async.SingleResultCallback; import java.util.concurrent.locks.Lock; import java.util.concurrent.locks.ReentrantLock; @@ -38,23 +36,7 @@ public static void withLock(final Lock lock, final Runnable action) { }); } - public static void withLockAsync(final StampedLock lock, final AsyncRunnable runnable, - final SingleResultCallback callback) { - long stamp; - try { - stamp = lock.writeLockInterruptibly(); - } catch (InterruptedException e) { - Thread.currentThread().interrupt(); - callback.onResult(null, new MongoInterruptedException("Interrupted waiting for lock", e)); - return; - } - - runnable.thenAlwaysRunAndFinish(() -> { - lock.unlockWrite(stamp); - }, callback); - } - - public static void withLock(final StampedLock lock, final Runnable runnable) { + public static void withInterruptibleLock(final StampedLock lock, final Runnable runnable) throws MongoInterruptedException{ long stamp; try { stamp = lock.writeLockInterruptibly(); diff --git a/driver-core/src/main/com/mongodb/internal/authentication/AzureCredentialHelper.java b/driver-core/src/main/com/mongodb/internal/authentication/AzureCredentialHelper.java index 7c75e397d2a..2a48b8b6fc3 100644 --- a/driver-core/src/main/com/mongodb/internal/authentication/AzureCredentialHelper.java +++ b/driver-core/src/main/com/mongodb/internal/authentication/AzureCredentialHelper.java @@ -18,10 +18,13 @@ import com.mongodb.MongoClientException; import com.mongodb.internal.ExpirableValue; +import com.mongodb.lang.Nullable; import org.bson.BsonDocument; import org.bson.BsonString; import org.bson.json.JsonParseException; +import java.io.UnsupportedEncodingException; +import java.net.URLEncoder; import java.time.Duration; import java.util.HashMap; import java.util.Map; @@ -55,33 +58,11 @@ public static BsonDocument obtainFromEnvironment() { if (cachedValue.isPresent()) { accessToken = cachedValue.get(); } else { - String endpoint = "http://" + "169.254.169.254:80" - + "/metadata/identity/oauth2/token?api-version=2018-02-01&resource=https://vault.azure.net"; - - Map headers = new HashMap<>(); - headers.put("Metadata", "true"); - headers.put("Accept", "application/json"); - long startNanoTime = System.nanoTime(); - BsonDocument responseDocument; - try { - responseDocument = BsonDocument.parse(getHttpContents("GET", endpoint, headers)); - } catch (JsonParseException e) { - throw new MongoClientException("Exception parsing JSON from Azure IMDS metadata response.", e); - } - - if (!responseDocument.isString(ACCESS_TOKEN_FIELD)) { - throw new MongoClientException(String.format( - "The %s field from Azure IMDS metadata response is missing or is not a string", ACCESS_TOKEN_FIELD)); - } - if (!responseDocument.isString(EXPIRES_IN_FIELD)) { - throw new MongoClientException(String.format( - "The %s field from Azure IMDS metadata response is missing or is not a string", EXPIRES_IN_FIELD)); - } - accessToken = responseDocument.getString(ACCESS_TOKEN_FIELD).getValue(); - int expiresInSeconds = Integer.parseInt(responseDocument.getString(EXPIRES_IN_FIELD).getValue()); - cachedAccessToken = ExpirableValue.expirable(accessToken, Duration.ofSeconds(expiresInSeconds).minus(Duration.ofMinutes(1)), - startNanoTime); + CredentialInfo response = fetchAzureCredentialInfo("https://vault.azure.net", null); + accessToken = response.getAccessToken(); + Duration duration = response.getExpiresIn().minus(Duration.ofMinutes(1)); + cachedAccessToken = ExpirableValue.expirable(accessToken, duration, startNanoTime); } } finally { CACHED_ACCESS_TOKEN_LOCK.unlock(); @@ -90,6 +71,44 @@ public static BsonDocument obtainFromEnvironment() { return new BsonDocument("accessToken", new BsonString(accessToken)); } + public static CredentialInfo fetchAzureCredentialInfo(final String resource, @Nullable final String clientId) { + String endpoint = "http://169.254.169.254:80" + + "/metadata/identity/oauth2/token?api-version=2018-02-01" + + "&resource=" + getEncoded(resource) + + (clientId == null ? "" : "&client_id=" + getEncoded(clientId)); + + Map headers = new HashMap<>(); + headers.put("Metadata", "true"); + headers.put("Accept", "application/json"); + + BsonDocument responseDocument; + try { + responseDocument = BsonDocument.parse(getHttpContents("GET", endpoint, headers)); + } catch (JsonParseException e) { + throw new MongoClientException("Exception parsing JSON from Azure IMDS metadata response.", e); + } + + if (!responseDocument.isString(ACCESS_TOKEN_FIELD)) { + throw new MongoClientException(String.format( + "The %s field from Azure IMDS metadata response is missing or is not a string", ACCESS_TOKEN_FIELD)); + } + if (!responseDocument.isString(EXPIRES_IN_FIELD)) { + throw new MongoClientException(String.format( + "The %s field from Azure IMDS metadata response is missing or is not a string", EXPIRES_IN_FIELD)); + } + String accessToken = responseDocument.getString(ACCESS_TOKEN_FIELD).getValue(); + int expiresInSeconds = Integer.parseInt(responseDocument.getString(EXPIRES_IN_FIELD).getValue()); + return new CredentialInfo(accessToken, Duration.ofSeconds(expiresInSeconds)); + } + + static String getEncoded(final String resource) { + try { + return URLEncoder.encode(resource, "UTF-8"); + } catch (UnsupportedEncodingException e) { + throw new RuntimeException(e); + } + } + private AzureCredentialHelper() { } } diff --git a/driver-core/src/main/com/mongodb/internal/authentication/CredentialInfo.java b/driver-core/src/main/com/mongodb/internal/authentication/CredentialInfo.java new file mode 100644 index 00000000000..8b1e601b13a --- /dev/null +++ b/driver-core/src/main/com/mongodb/internal/authentication/CredentialInfo.java @@ -0,0 +1,44 @@ +/* + * Copyright 2008-present MongoDB, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.mongodb.internal.authentication; + +import java.time.Duration; + +/** + *

    This class is not part of the public API and may be removed or changed at any time

    + */ +public final class CredentialInfo { + private final String accessToken; + private final Duration expiresIn; + + /** + * @param expiresIn The meaning of {@linkplain Duration#isZero() zero-length} duration is the same as in + * {@link com.mongodb.MongoCredential.OidcCallbackResult#OidcCallbackResult(String, Duration)}. + */ + public CredentialInfo(final String accessToken, final Duration expiresIn) { + this.accessToken = accessToken; + this.expiresIn = expiresIn; + } + + public String getAccessToken() { + return accessToken; + } + + public Duration getExpiresIn() { + return expiresIn; + } +} diff --git a/driver-core/src/main/com/mongodb/internal/authentication/GcpCredentialHelper.java b/driver-core/src/main/com/mongodb/internal/authentication/GcpCredentialHelper.java index 92b3fdd6040..3f0272da48c 100644 --- a/driver-core/src/main/com/mongodb/internal/authentication/GcpCredentialHelper.java +++ b/driver-core/src/main/com/mongodb/internal/authentication/GcpCredentialHelper.java @@ -19,9 +19,11 @@ import com.mongodb.MongoClientException; import org.bson.BsonDocument; +import java.time.Duration; import java.util.HashMap; import java.util.Map; +import static com.mongodb.internal.authentication.AzureCredentialHelper.getEncoded; import static com.mongodb.internal.authentication.HttpHelper.getHttpContents; /** @@ -44,6 +46,17 @@ public static BsonDocument obtainFromEnvironment() { } } + public static CredentialInfo fetchGcpCredentialInfo(final String audience) { + String endpoint = "http://metadata/computeMetadata/v1/instance/service-accounts/default/identity?audience=" + + getEncoded(audience); + Map header = new HashMap<>(); + header.put("Metadata-Flavor", "Google"); + String response = getHttpContents("GET", endpoint, header); + return new CredentialInfo( + response, + Duration.ZERO); + } + private GcpCredentialHelper() { } } diff --git a/driver-core/src/main/com/mongodb/internal/connection/OidcAuthenticator.java b/driver-core/src/main/com/mongodb/internal/connection/OidcAuthenticator.java index 6b2362cbc1f..af26abbf87f 100644 --- a/driver-core/src/main/com/mongodb/internal/connection/OidcAuthenticator.java +++ b/driver-core/src/main/com/mongodb/internal/connection/OidcAuthenticator.java @@ -31,6 +31,9 @@ import com.mongodb.internal.Locks; import com.mongodb.internal.VisibleForTesting; import com.mongodb.internal.async.SingleResultCallback; +import com.mongodb.internal.authentication.AzureCredentialHelper; +import com.mongodb.internal.authentication.CredentialInfo; +import com.mongodb.internal.authentication.GcpCredentialHelper; import com.mongodb.lang.Nullable; import org.bson.BsonDocument; import org.bson.BsonString; @@ -51,12 +54,13 @@ import static com.mongodb.AuthenticationMechanism.MONGODB_OIDC; import static com.mongodb.MongoCredential.ALLOWED_HOSTS_KEY; import static com.mongodb.MongoCredential.DEFAULT_ALLOWED_HOSTS; +import static com.mongodb.MongoCredential.ENVIRONMENT_KEY; import static com.mongodb.MongoCredential.IdpInfo; +import static com.mongodb.MongoCredential.OIDC_CALLBACK_KEY; import static com.mongodb.MongoCredential.OIDC_HUMAN_CALLBACK_KEY; import static com.mongodb.MongoCredential.OidcCallback; import static com.mongodb.MongoCredential.OidcCallbackContext; -import static com.mongodb.MongoCredential.PROVIDER_NAME_KEY; -import static com.mongodb.MongoCredential.OIDC_CALLBACK_KEY; +import static com.mongodb.MongoCredential.TOKEN_RESOURCE_KEY; import static com.mongodb.assertions.Assertions.assertFalse; import static com.mongodb.assertions.Assertions.assertNotNull; import static com.mongodb.assertions.Assertions.assertTrue; @@ -69,11 +73,22 @@ */ public final class OidcAuthenticator extends SaslAuthenticator { - private static final List SUPPORTED_PROVIDERS = Arrays.asList("aws"); + private static final String TEST_ENVIRONMENT = "test"; + private static final String AZURE_ENVIRONMENT = "azure"; + private static final String GCP_ENVIRONMENT = "gcp"; + private static final List IMPLEMENTED_ENVIRONMENTS = Arrays.asList( + AZURE_ENVIRONMENT, GCP_ENVIRONMENT, TEST_ENVIRONMENT); + private static final List USER_SUPPORTED_ENVIRONMENTS = Arrays.asList( + AZURE_ENVIRONMENT, GCP_ENVIRONMENT); + private static final List REQUIRES_TOKEN_RESOURCE = Arrays.asList( + AZURE_ENVIRONMENT, GCP_ENVIRONMENT); + private static final List ALLOWS_USERNAME = Arrays.asList( + AZURE_ENVIRONMENT); private static final Duration CALLBACK_TIMEOUT = Duration.ofMinutes(5); - public static final String AWS_WEB_IDENTITY_TOKEN_FILE = "AWS_WEB_IDENTITY_TOKEN_FILE"; + public static final String OIDC_TOKEN_FILE = "OIDC_TOKEN_FILE"; + private static final int CALLBACK_API_VERSION_NUMBER = 1; @Nullable @@ -113,10 +128,9 @@ protected SaslClient createSaslClient(final ServerAddress serverAddress) { @Nullable public BsonDocument createSpeculativeAuthenticateCommand(final InternalConnection connection) { try { - if (isAutomaticAuthentication()) { - return wrapInSpeculative(prepareAwsTokenFromFileAsJwt()); - } - String cachedAccessToken = getCachedAccessToken(); + String cachedAccessToken = getMongoCredentialWithCache() + .getOidcCacheEntry() + .getCachedAccessToken(); if (cachedAccessToken != null) { return wrapInSpeculative(prepareTokenAsJwt(cachedAccessToken)); } else { @@ -152,11 +166,8 @@ public void setSpeculativeAuthenticateResponse(@Nullable final BsonDocument resp speculativeAuthenticateResponse = response; } - private boolean isAutomaticAuthentication() { - return getOidcCallbackMechanismProperty(PROVIDER_NAME_KEY) == null; - } - private boolean isHumanCallback() { + // built-in providers (aws, azure...) are considered machine callbacks return getOidcCallbackMechanismProperty(OIDC_HUMAN_CALLBACK_KEY) != null; } @@ -167,10 +178,46 @@ private OidcCallback getOidcCallbackMechanismProperty(final String key) { .getMechanismProperty(key, null); } - @Nullable private OidcCallback getRequestCallback() { - OidcCallback machine = getOidcCallbackMechanismProperty(OIDC_CALLBACK_KEY); - return machine != null ? machine : getOidcCallbackMechanismProperty(OIDC_HUMAN_CALLBACK_KEY); + String environment = getMongoCredential().getMechanismProperty(ENVIRONMENT_KEY, null); + OidcCallback machine; + if (TEST_ENVIRONMENT.equals(environment)) { + machine = getTestCallback(); + } else if (AZURE_ENVIRONMENT.equals(environment)) { + machine = getAzureCallback(getMongoCredential()); + } else if (GCP_ENVIRONMENT.equals(environment)) { + machine = getGcpCallback(getMongoCredential()); + } else { + machine = getOidcCallbackMechanismProperty(OIDC_CALLBACK_KEY); + } + OidcCallback human = getOidcCallbackMechanismProperty(OIDC_HUMAN_CALLBACK_KEY); + return machine != null ? machine : assertNotNull(human); + } + + private static OidcCallback getTestCallback() { + return (context) -> { + String accessToken = readTokenFromFile(); + return new OidcCallbackResult(accessToken); + }; + } + + @VisibleForTesting(otherwise = VisibleForTesting.AccessModifier.PRIVATE) + static OidcCallback getAzureCallback(final MongoCredential credential) { + return (context) -> { + String resource = assertNotNull(credential.getMechanismProperty(TOKEN_RESOURCE_KEY, null)); + String clientId = credential.getUserName(); + CredentialInfo response = AzureCredentialHelper.fetchAzureCredentialInfo(resource, clientId); + return new OidcCallbackResult(response.getAccessToken(), response.getExpiresIn()); + }; + } + + @VisibleForTesting(otherwise = VisibleForTesting.AccessModifier.PRIVATE) + static OidcCallback getGcpCallback(final MongoCredential credential) { + return (context) -> { + String resource = assertNotNull(credential.getMechanismProperty(TOKEN_RESOURCE_KEY, null)); + CredentialInfo response = GcpCredentialHelper.fetchGcpCredentialInfo(resource); + return new OidcCallbackResult(response.getAccessToken(), response.getExpiresIn()); + }; } @Override @@ -239,17 +286,15 @@ private void authenticationLoopAsync(final InternalConnection connection, final } private byte[] evaluate(final byte[] challenge) { - if (isAutomaticAuthentication()) { - return prepareAwsTokenFromFileAsJwt(); - } byte[][] jwt = new byte[1][]; - Locks.withLock(getMongoCredentialWithCache().getOidcLock(), () -> { + Locks.withInterruptibleLock(getMongoCredentialWithCache().getOidcLock(), () -> { OidcCacheEntry oidcCacheEntry = getMongoCredentialWithCache().getOidcCacheEntry(); String cachedRefreshToken = oidcCacheEntry.getRefreshToken(); IdpInfo cachedIdpInfo = oidcCacheEntry.getIdpInfo(); String cachedAccessToken = validatedCachedAccessToken(); - OidcCallback requestCallback = assertNotNull(getRequestCallback()); + OidcCallback requestCallback = getRequestCallback(); boolean isHuman = isHumanCallback(); + String userName = getMongoCredentialWithCache().getCredential().getUserName(); if (cachedAccessToken != null) { fallbackState = FallbackState.PHASE_1_CACHED_TOKEN; @@ -261,7 +306,7 @@ private byte[] evaluate(final byte[] challenge) { // Invoke Callback using cached Refresh Token fallbackState = FallbackState.PHASE_2_REFRESH_CALLBACK_TOKEN; OidcCallbackResult result = requestCallback.onRequest(new OidcCallbackContextImpl( - CALLBACK_TIMEOUT, cachedIdpInfo, cachedRefreshToken)); + CALLBACK_TIMEOUT, cachedIdpInfo, cachedRefreshToken, userName)); jwt[0] = populateCacheWithCallbackResultAndPrepareJwt(cachedIdpInfo, result); } else { // cache is empty @@ -270,7 +315,7 @@ private byte[] evaluate(final byte[] challenge) { // no principal request fallbackState = FallbackState.PHASE_3B_CALLBACK_TOKEN; OidcCallbackResult result = requestCallback.onRequest(new OidcCallbackContextImpl( - CALLBACK_TIMEOUT)); + CALLBACK_TIMEOUT, userName)); jwt[0] = populateCacheWithCallbackResultAndPrepareJwt(null, result); if (result.getRefreshToken() != null) { throw new MongoConfigurationException( @@ -294,13 +339,13 @@ private byte[] evaluate(final byte[] challenge) { if (!alreadyTriedPrincipal && idpInfoNotPresent) { // request for idp info, only in the human workflow fallbackState = FallbackState.PHASE_3A_PRINCIPAL; - jwt[0] = prepareUsername(getMongoCredentialWithCache().getCredential().getUserName()); + jwt[0] = prepareUsername(userName); } else { IdpInfo idpInfo = toIdpInfo(challenge); // there is no cached refresh token fallbackState = FallbackState.PHASE_3B_CALLBACK_TOKEN; OidcCallbackResult result = requestCallback.onRequest(new OidcCallbackContextImpl( - CALLBACK_TIMEOUT, idpInfo, null)); + CALLBACK_TIMEOUT, idpInfo, null, userName)); jwt[0] = populateCacheWithCallbackResultAndPrepareJwt(idpInfo, result); } } @@ -316,7 +361,7 @@ private byte[] evaluate(final byte[] challenge) { private String validatedCachedAccessToken() { MongoCredentialWithCache mongoCredentialWithCache = getMongoCredentialWithCache(); OidcCacheEntry cacheEntry = mongoCredentialWithCache.getOidcCacheEntry(); - String cachedAccessToken = getCachedAccessToken(); + String cachedAccessToken = cacheEntry.getCachedAccessToken(); String invalidConnectionAccessToken = connectionLastAccessToken; if (cachedAccessToken != null) { @@ -335,7 +380,7 @@ private boolean clientIsComplete() { private boolean shouldRetryHandler() { boolean[] result = new boolean[1]; - Locks.withLock(getMongoCredentialWithCache().getOidcLock(), () -> { + Locks.withInterruptibleLock(getMongoCredentialWithCache().getOidcLock(), () -> { MongoCredentialWithCache mongoCredentialWithCache = getMongoCredentialWithCache(); OidcCacheEntry cacheEntry = mongoCredentialWithCache.getOidcCacheEntry(); if (fallbackState == FallbackState.PHASE_1_CACHED_TOKEN) { @@ -360,13 +405,6 @@ private boolean shouldRetryHandler() { return result[0]; } - @Nullable - private String getCachedAccessToken() { - return getMongoCredentialWithCache() - .getOidcCacheEntry() - .getCachedAccessToken(); - } - static final class OidcCacheEntry { @Nullable private final String accessToken; @@ -443,18 +481,18 @@ public boolean isComplete() { } - private static String readAwsTokenFromFile() { - String path = System.getenv(AWS_WEB_IDENTITY_TOKEN_FILE); + private static String readTokenFromFile() { + String path = System.getenv(OIDC_TOKEN_FILE); if (path == null) { throw new MongoClientException( - format("Environment variable must be specified: %s", AWS_WEB_IDENTITY_TOKEN_FILE)); + format("Environment variable must be specified: %s", OIDC_TOKEN_FILE)); } try { return new String(Files.readAllBytes(Paths.get(path)), StandardCharsets.UTF_8); } catch (IOException e) { throw new MongoClientException(format( "Could not read file specified by environment variable: %s at path: %s", - AWS_WEB_IDENTITY_TOKEN_FILE, path), e); + OIDC_TOKEN_FILE, path), e); } } @@ -483,14 +521,13 @@ private IdpInfo toIdpInfo(final byte[] challenge) { validateAllowedHosts(getMongoCredential()); BsonDocument c = new RawBsonDocument(challenge); String issuer = c.getString("issuer").getValue(); - String clientId = c.getString("clientId").getValue(); + String clientId = !c.containsKey("clientId") ? null : c.getString("clientId").getValue(); return new IdpInfoImpl( issuer, clientId, getStringArray(c, "requestScopes")); } - @Nullable private static List getStringArray(final BsonDocument document, final String key) { if (!document.isArray(key)) { @@ -529,11 +566,6 @@ private byte[] prepareTokenAsJwt(final String accessToken) { return toJwtDocument(accessToken); } - private static byte[] prepareAwsTokenFromFileAsJwt() { - String accessToken = readAwsTokenFromFile(); - return toJwtDocument(accessToken); - } - private static byte[] toJwtDocument(final String accessToken) { return toBson(new BsonDocument().append("jwt", new BsonString(accessToken))); } @@ -553,10 +585,10 @@ public static void validateOidcCredentialConstruction( throw new IllegalArgumentException("source must be '$external'"); } - Object providerName = mechanismProperties.get(PROVIDER_NAME_KEY.toLowerCase()); - if (providerName != null) { - if (!(providerName instanceof String) || !SUPPORTED_PROVIDERS.contains(providerName)) { - throw new IllegalArgumentException(PROVIDER_NAME_KEY + " must be one of: " + SUPPORTED_PROVIDERS); + Object environmentName = mechanismProperties.get(ENVIRONMENT_KEY.toLowerCase()); + if (environmentName != null) { + if (!(environmentName instanceof String) || !IMPLEMENTED_ENVIRONMENTS.contains(environmentName)) { + throw new IllegalArgumentException(ENVIRONMENT_KEY + " must be one of: " + USER_SUPPORTED_ENVIRONMENTS); } } } @@ -571,13 +603,13 @@ public static void validateCreateOidcCredential(@Nullable final char[] password) @VisibleForTesting(otherwise = VisibleForTesting.AccessModifier.PRIVATE) public static void validateBeforeUse(final MongoCredential credential) { String userName = credential.getUserName(); - Object providerName = credential.getMechanismProperty(PROVIDER_NAME_KEY, null); + Object environmentName = credential.getMechanismProperty(ENVIRONMENT_KEY, null); Object machineCallback = credential.getMechanismProperty(OIDC_CALLBACK_KEY, null); Object humanCallback = credential.getMechanismProperty(OIDC_HUMAN_CALLBACK_KEY, null); - if (providerName == null) { + if (environmentName == null) { // callback if (machineCallback == null && humanCallback == null) { - throw new IllegalArgumentException("Either " + PROVIDER_NAME_KEY + throw new IllegalArgumentException("Either " + ENVIRONMENT_KEY + " or " + OIDC_CALLBACK_KEY + " or " + OIDC_HUMAN_CALLBACK_KEY + " must be specified"); @@ -588,20 +620,32 @@ public static void validateBeforeUse(final MongoCredential credential) { + " must not be specified"); } } else { - if (userName != null) { - throw new IllegalArgumentException("user name must not be specified when " + PROVIDER_NAME_KEY + " is specified"); + if (!(environmentName instanceof String)) { + throw new IllegalArgumentException(ENVIRONMENT_KEY + " must be a String"); + } + if (userName != null && !ALLOWS_USERNAME.contains(environmentName)) { + throw new IllegalArgumentException("user name must not be specified when " + ENVIRONMENT_KEY + " is specified"); } if (machineCallback != null) { - throw new IllegalArgumentException(OIDC_CALLBACK_KEY + " must not be specified when " + PROVIDER_NAME_KEY + " is specified"); + throw new IllegalArgumentException(OIDC_CALLBACK_KEY + " must not be specified when " + ENVIRONMENT_KEY + " is specified"); } if (humanCallback != null) { - throw new IllegalArgumentException(OIDC_HUMAN_CALLBACK_KEY + " must not be specified when " + PROVIDER_NAME_KEY + " is specified"); + throw new IllegalArgumentException(OIDC_HUMAN_CALLBACK_KEY + " must not be specified when " + ENVIRONMENT_KEY + " is specified"); + } + String tokenResource = credential.getMechanismProperty(TOKEN_RESOURCE_KEY, null); + boolean hasTokenResourceProperty = tokenResource != null; + boolean tokenResourceSupported = REQUIRES_TOKEN_RESOURCE.contains(environmentName); + if (hasTokenResourceProperty != tokenResourceSupported) { + throw new IllegalArgumentException(TOKEN_RESOURCE_KEY + + " must be provided if and only if " + ENVIRONMENT_KEY + + " " + environmentName + " " + + " is one of: " + REQUIRES_TOKEN_RESOURCE + + ". " + TOKEN_RESOURCE_KEY + ": " + tokenResource); } } } } - @VisibleForTesting(otherwise = VisibleForTesting.AccessModifier.PRIVATE) static class OidcCallbackContextImpl implements OidcCallbackContext { private final Duration timeout; @@ -609,20 +653,26 @@ static class OidcCallbackContextImpl implements OidcCallbackContext { private final IdpInfo idpInfo; @Nullable private final String refreshToken; + @Nullable + private final String userName; - OidcCallbackContextImpl(final Duration timeout) { + OidcCallbackContextImpl(final Duration timeout, @Nullable final String userName) { this.timeout = assertNotNull(timeout); this.idpInfo = null; this.refreshToken = null; + this.userName = userName; } - OidcCallbackContextImpl(final Duration timeout, final IdpInfo idpInfo, @Nullable final String refreshToken) { + OidcCallbackContextImpl(final Duration timeout, final IdpInfo idpInfo, + @Nullable final String refreshToken, @Nullable final String userName) { this.timeout = assertNotNull(timeout); this.idpInfo = assertNotNull(idpInfo); this.refreshToken = refreshToken; + this.userName = userName; } @Override + @Nullable public IdpInfo getIdpInfo() { return idpInfo; } @@ -638,20 +688,28 @@ public int getVersion() { } @Override + @Nullable public String getRefreshToken() { return refreshToken; } + + @Override + @Nullable + public String getUserName() { + return userName; + } } @VisibleForTesting(otherwise = VisibleForTesting.AccessModifier.PRIVATE) static final class IdpInfoImpl implements IdpInfo { private final String issuer; + @Nullable private final String clientId; private final List requestScopes; - IdpInfoImpl(final String issuer, final String clientId, @Nullable final List requestScopes) { + IdpInfoImpl(final String issuer, @Nullable final String clientId, @Nullable final List requestScopes) { this.issuer = assertNotNull(issuer); - this.clientId = assertNotNull(clientId); + this.clientId = clientId; this.requestScopes = requestScopes == null ? Collections.emptyList() : Collections.unmodifiableList(requestScopes); @@ -663,6 +721,7 @@ public String getIssuer() { } @Override + @Nullable public String getClientId() { return clientId; } diff --git a/driver-core/src/test/resources/auth/legacy/connection-string.json b/driver-core/src/test/resources/auth/legacy/connection-string.json index f8521be9d19..072dd176dc8 100644 --- a/driver-core/src/test/resources/auth/legacy/connection-string.json +++ b/driver-core/src/test/resources/auth/legacy/connection-string.json @@ -446,8 +446,8 @@ } }, { - "description": "should recognise the mechanism with aws provider (MONGODB-OIDC)", - "uri": "mongodb://localhost/?authMechanism=MONGODB-OIDC&authMechanismProperties=PROVIDER_NAME:aws", + "description": "should recognise the mechanism with test environment (MONGODB-OIDC)", + "uri": "mongodb://localhost/?authMechanism=MONGODB-OIDC&authMechanismProperties=ENVIRONMENT:test", "valid": true, "credential": { "username": null, @@ -455,13 +455,13 @@ "source": "$external", "mechanism": "MONGODB-OIDC", "mechanism_properties": { - "PROVIDER_NAME": "aws" + "ENVIRONMENT": "test" } } }, { - "description": "should recognise the mechanism when auth source is explicitly specified and with provider (MONGODB-OIDC)", - "uri": "mongodb://localhost/?authMechanism=MONGODB-OIDC&authSource=$external&authMechanismProperties=PROVIDER_NAME:aws", + "description": "should recognise the mechanism when auth source is explicitly specified and with environment (MONGODB-OIDC)", + "uri": "mongodb://localhost/?authMechanism=MONGODB-OIDC&authSource=$external&authMechanismProperties=ENVIRONMENT:test", "valid": true, "credential": { "username": null, @@ -469,30 +469,30 @@ "source": "$external", "mechanism": "MONGODB-OIDC", "mechanism_properties": { - "PROVIDER_NAME": "aws" + "ENVIRONMENT": "test" } } }, { "description": "should throw an exception if supplied a password (MONGODB-OIDC)", - "uri": "mongodb://user:pass@localhost/?authMechanism=MONGODB-OIDC&authMechanismProperties=PROVIDER_NAME:aws", + "uri": "mongodb://user:pass@localhost/?authMechanism=MONGODB-OIDC&authMechanismProperties=ENVIRONMENT:test", "valid": false, "credential": null }, { - "description": "should throw an exception if username is specified for aws (MONGODB-OIDC)", - "uri": "mongodb://principalName@localhost/?authMechanism=MONGODB-OIDC&PROVIDER_NAME:aws", + "description": "should throw an exception if username is specified for test (MONGODB-OIDC)", + "uri": "mongodb://principalName@localhost/?authMechanism=MONGODB-OIDC&ENVIRONMENT:test", "valid": false, "credential": null }, { - "description": "should throw an exception if specified provider is not supported (MONGODB-OIDC)", - "uri": "mongodb://localhost/?authMechanism=MONGODB-OIDC&authMechanismProperties=PROVIDER_NAME:invalid", + "description": "should throw an exception if specified environment is not supported (MONGODB-OIDC)", + "uri": "mongodb://localhost/?authMechanism=MONGODB-OIDC&authMechanismProperties=ENVIRONMENT:invalid", "valid": false, "credential": null }, { - "description": "should throw an exception if neither provider nor callbacks specified (MONGODB-OIDC)", + "description": "should throw an exception if neither environment nor callbacks specified (MONGODB-OIDC)", "uri": "mongodb://localhost/?authMechanism=MONGODB-OIDC", "valid": false, "credential": null @@ -502,6 +502,135 @@ "uri": "mongodb://localhost/?authMechanism=MONGODB-OIDC&authMechanismProperties=UnsupportedProperty:unexisted", "valid": false, "credential": null + }, + { + "description": "should recognise the mechanism with azure provider (MONGODB-OIDC)", + "uri": "mongodb://localhost/?authMechanism=MONGODB-OIDC&authMechanismProperties=ENVIRONMENT:azure,TOKEN_RESOURCE:foo", + "valid": true, + "credential": { + "username": null, + "password": null, + "source": "$external", + "mechanism": "MONGODB-OIDC", + "mechanism_properties": { + "ENVIRONMENT": "azure", + "TOKEN_RESOURCE": "foo" + } + } + }, + { + "description": "should accept a username with azure provider (MONGODB-OIDC)", + "uri": "mongodb://user@localhost/?authMechanism=MONGODB-OIDC&authMechanismProperties=ENVIRONMENT:azure,TOKEN_RESOURCE:foo", + "valid": true, + "credential": { + "username": "user", + "password": null, + "source": "$external", + "mechanism": "MONGODB-OIDC", + "mechanism_properties": { + "ENVIRONMENT": "azure", + "TOKEN_RESOURCE": "foo" + } + } + }, + { + "description": "should accept a url-encoded TOKEN_RESOURCE (MONGODB-OIDC)", + "uri": "mongodb://user@localhost/?authMechanism=MONGODB-OIDC&authMechanismProperties=ENVIRONMENT:azure,TOKEN_RESOURCE:mongodb%3A%2F%2Ftest-cluster", + "valid": true, + "credential": { + "username": "user", + "password": null, + "source": "$external", + "mechanism": "MONGODB-OIDC", + "mechanism_properties": { + "ENVIRONMENT": "azure", + "TOKEN_RESOURCE": "mongodb://test-cluster" + } + } + }, + { + "description": "should accept an un-encoded TOKEN_RESOURCE (MONGODB-OIDC)", + "uri": "mongodb://user@localhost/?authMechanism=MONGODB-OIDC&authMechanismProperties=ENVIRONMENT:azure,TOKEN_RESOURCE:mongodb://test-cluster", + "valid": true, + "credential": { + "username": "user", + "password": null, + "source": "$external", + "mechanism": "MONGODB-OIDC", + "mechanism_properties": { + "ENVIRONMENT": "azure", + "TOKEN_RESOURCE": "mongodb://test-cluster" + } + } + }, + { + "description": "should handle a complicated url-encoded TOKEN_RESOURCE (MONGODB-OIDC)", + "uri": "mongodb://user@localhost/?authMechanism=MONGODB-OIDC&authMechanismProperties=ENVIRONMENT:azure,TOKEN_RESOURCE:abc%2Cd%25ef%3Ag%26hi", + "valid": true, + "credential": { + "username": "user", + "password": null, + "source": "$external", + "mechanism": "MONGODB-OIDC", + "mechanism_properties": { + "ENVIRONMENT": "azure", + "TOKEN_RESOURCE": "abc,d%ef:g&hi" + } + } + }, + { + "description": "should url-encode a TOKEN_RESOURCE (MONGODB-OIDC)", + "uri": "mongodb://user@localhost/?authMechanism=MONGODB-OIDC&authMechanismProperties=ENVIRONMENT:azure,TOKEN_RESOURCE:a$b", + "valid": true, + "credential": { + "username": "user", + "password": null, + "source": "$external", + "mechanism": "MONGODB-OIDC", + "mechanism_properties": { + "ENVIRONMENT": "azure", + "TOKEN_RESOURCE": "a$b" + } + } + }, + { + "description": "should accept a username and throw an error for a password with azure provider (MONGODB-OIDC)", + "uri": "mongodb://user:pass@localhost/?authMechanism=MONGODB-OIDC&authMechanismProperties=ENVIRONMENT:azure,TOKEN_RESOURCE:foo", + "valid": false, + "credential": null + }, + { + "description": "should throw an exception if no token audience is given for azure provider (MONGODB-OIDC)", + "uri": "mongodb://username@localhost/?authMechanism=MONGODB-OIDC&authMechanismProperties=ENVIRONMENT:azure", + "valid": false, + "credential": null + }, + { + "description": "should recognise the mechanism with gcp provider (MONGODB-OIDC)", + "uri": "mongodb://localhost/?authMechanism=MONGODB-OIDC&authMechanismProperties=ENVIRONMENT:gcp,TOKEN_RESOURCE:foo", + "valid": true, + "credential": { + "username": null, + "password": null, + "source": "$external", + "mechanism": "MONGODB-OIDC", + "mechanism_properties": { + "ENVIRONMENT": "gcp", + "TOKEN_RESOURCE": "foo" + } + } + }, + { + "description": "should throw an error for a username and password with gcp provider (MONGODB-OIDC)", + "uri": "mongodb://user:pass@localhost/?authMechanism=MONGODB-OIDC&authMechanismProperties=ENVIRONMENT:gcp,TOKEN_RESOURCE:foo", + "valid": false, + "credential": null + }, + { + "description": "should throw an error if not TOKEN_RESOURCE with gcp provider (MONGODB-OIDC)", + "uri": "mongodb://user:pass@localhost/?authMechanism=MONGODB-OIDC&authMechanismProperties=ENVIRONMENT:gcp", + "valid": false, + "credential": null } ] } diff --git a/driver-core/src/test/resources/unified-test-format/auth/mongodb-oidc-no-retry.json b/driver-core/src/test/resources/unified-test-format/auth/mongodb-oidc-no-retry.json index 7287c2486f0..83065f492ae 100644 --- a/driver-core/src/test/resources/unified-test-format/auth/mongodb-oidc-no-retry.json +++ b/driver-core/src/test/resources/unified-test-format/auth/mongodb-oidc-no-retry.json @@ -52,9 +52,7 @@ { "collectionName": "collName", "databaseName": "test", - "documents": [ - - ] + "documents": [] } ], "tests": [ @@ -65,12 +63,9 @@ "name": "find", "object": "collection0", "arguments": { - "filter": { - } + "filter": {} }, - "expectResult": [ - - ] + "expectResult": [] } ], "expectEvents": [ @@ -81,8 +76,7 @@ "commandStartedEvent": { "command": { "find": "collName", - "filter": { - } + "filter": {} } } }, @@ -161,12 +155,9 @@ "name": "find", "object": "collection0", "arguments": { - "filter": { - } + "filter": {} }, - "expectResult": [ - - ] + "expectResult": [] } ], "expectEvents": [ @@ -177,8 +168,7 @@ "commandStartedEvent": { "command": { "find": "collName", - "filter": { - } + "filter": {} } } }, @@ -191,8 +181,7 @@ "commandStartedEvent": { "command": { "find": "collName", - "filter": { - } + "filter": {} } } }, @@ -324,12 +313,14 @@ "client": "failPointClient", "failPoint": { "configureFailPoint": "failCommand", - "mode": "alwaysOn", + "mode": { + "times": 1 + }, "data": { "failCommands": [ "saslStart" ], - "errorCode": 20 + "errorCode": 18 } } } @@ -399,12 +390,14 @@ "client": "failPointClient", "failPoint": { "configureFailPoint": "failCommand", - "mode": "alwaysOn", + "mode": { + "times": 1 + }, "data": { "failCommands": [ "saslStart" ], - "errorCode": 20 + "errorCode": 18 } } } @@ -419,7 +412,7 @@ } }, "expectError": { - "errorCode": 20 + "errorCode": 18 } } ] diff --git a/driver-core/src/test/unit/com/mongodb/ConnectionStringSpecification.groovy b/driver-core/src/test/unit/com/mongodb/ConnectionStringSpecification.groovy index e8731439a84..d56aa8a9c7c 100644 --- a/driver-core/src/test/unit/com/mongodb/ConnectionStringSpecification.groovy +++ b/driver-core/src/test/unit/com/mongodb/ConnectionStringSpecification.groovy @@ -601,7 +601,7 @@ class ConnectionStringSpecification extends Specification { new ConnectionString('mongodb://jeff@localhost/?' + 'authMechanism=GSSAPI' + '&authMechanismProperties=' + - 'SERVICE_NAME:foo:bar') + 'SERVICE_NAMEbar') // missing = then: thrown(IllegalArgumentException) diff --git a/driver-core/src/test/unit/com/mongodb/ConnectionStringUnitTest.java b/driver-core/src/test/unit/com/mongodb/ConnectionStringUnitTest.java index d2e41ebeafd..6a8d9ff4fc3 100644 --- a/driver-core/src/test/unit/com/mongodb/ConnectionStringUnitTest.java +++ b/driver-core/src/test/unit/com/mongodb/ConnectionStringUnitTest.java @@ -15,11 +15,16 @@ */ package com.mongodb; +import com.mongodb.assertions.Assertions; import com.mongodb.connection.ServerMonitoringMode; import org.junit.jupiter.api.Test; import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.ValueSource; +import java.io.UnsupportedEncodingException; +import java.net.URLEncoder; +import java.nio.charset.StandardCharsets; + import static org.junit.jupiter.api.Assertions.assertAll; import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertNotEquals; @@ -34,6 +39,48 @@ void defaults() { assertAll(() -> assertNull(connectionStringDefault.getServerMonitoringMode())); } + @Test + public void mustDecodeOidcIndividually() { + String string = "abc,d!@#$%^&*;ef:ghi"; + // encoded tags will fail parsing with an "invalid read preference tag" + // error if decoding is skipped. + String encodedTags = encode("dc:ny,rack:1"); + ConnectionString cs = new ConnectionString( + "mongodb://localhost/?readPreference=primaryPreferred&readPreferenceTags=" + encodedTags + + "&authMechanism=MONGODB-OIDC&authMechanismProperties=" + + "ENVIRONMENT:azure,TOKEN_RESOURCE:" + encode(string)); + MongoCredential credential = Assertions.assertNotNull(cs.getCredential()); + assertEquals(string, credential.getMechanismProperty("TOKEN_RESOURCE", null)); + } + + @Test + public void mustDecodeNonOidcAsWhole() { + // this string allows us to check if there is no double decoding + String rawValue = encode("ot her"); + assertAll(() -> { + // even though only one part has been encoded by the user, the whole option value (pre-split) must be decoded + ConnectionString cs = new ConnectionString( + "mongodb://foo:bar@example.com/?authMechanism=GSSAPI&authMechanismProperties=" + + "SERVICE_NAME:" + encode(rawValue) + ",CANONICALIZE_HOST_NAME:true&authSource=$external"); + MongoCredential credential = Assertions.assertNotNull(cs.getCredential()); + assertEquals(rawValue, credential.getMechanismProperty("SERVICE_NAME", null)); + }, () -> { + ConnectionString cs = new ConnectionString( + "mongodb://foo:bar@example.com/?authMechanism=GSSAPI&authMechanismProperties=" + + encode("SERVICE_NAME:" + rawValue + ",CANONICALIZE_HOST_NAME:true&authSource=$external")); + MongoCredential credential = Assertions.assertNotNull(cs.getCredential()); + assertEquals(rawValue, credential.getMechanismProperty("SERVICE_NAME", null)); + }); + } + + private static String encode(final String string) { + try { + return URLEncoder.encode(string, StandardCharsets.UTF_8.name()); + } catch (UnsupportedEncodingException e) { + throw new RuntimeException(e); + } + } + @ParameterizedTest @ValueSource(strings = {DEFAULT_OPTIONS + "serverMonitoringMode=stream"}) void equalAndHashCode(final String connectionString) { diff --git a/driver-sync/src/test/functional/com/mongodb/client/unified/Entities.java b/driver-sync/src/test/functional/com/mongodb/client/unified/Entities.java index 7e83f802279..76e49d68cdb 100644 --- a/driver-sync/src/test/functional/com/mongodb/client/unified/Entities.java +++ b/driver-sync/src/test/functional/com/mongodb/client/unified/Entities.java @@ -16,7 +16,6 @@ package com.mongodb.client.unified; -import com.mongodb.AuthenticationMechanism; import com.mongodb.ClientEncryptionSettings; import com.mongodb.ClientSessionOptions; import com.mongodb.MongoClientSettings; @@ -26,14 +25,8 @@ import com.mongodb.ReadPreference; import com.mongodb.ServerApi; import com.mongodb.ServerApiVersion; -import com.mongodb.internal.connection.OidcAuthenticator; -import com.mongodb.event.TestServerMonitorListener; -import com.mongodb.internal.connection.ServerMonitoringModeUtil; -import com.mongodb.internal.connection.TestClusterListener; -import com.mongodb.logging.TestLoggingInterceptor; import com.mongodb.TransactionOptions; import com.mongodb.WriteConcern; -import com.mongodb.assertions.Assertions; import com.mongodb.client.ClientSession; import com.mongodb.client.MongoClient; import com.mongodb.client.MongoCollection; @@ -62,11 +55,15 @@ import com.mongodb.event.ConnectionPoolListener; import com.mongodb.event.ConnectionPoolReadyEvent; import com.mongodb.event.ConnectionReadyEvent; +import com.mongodb.event.TestServerMonitorListener; +import com.mongodb.internal.connection.ServerMonitoringModeUtil; +import com.mongodb.internal.connection.TestClusterListener; import com.mongodb.internal.connection.TestCommandListener; import com.mongodb.internal.connection.TestConnectionPoolListener; import com.mongodb.internal.connection.TestServerListener; import com.mongodb.internal.logging.LogMessage; import com.mongodb.lang.NonNull; +import com.mongodb.logging.TestLoggingInterceptor; import org.bson.BsonArray; import org.bson.BsonBoolean; import org.bson.BsonDocument; @@ -76,12 +73,6 @@ import org.bson.BsonString; import org.bson.BsonValue; -import java.io.IOException; -import java.nio.charset.StandardCharsets; -import java.nio.file.Files; -import java.nio.file.Path; -import java.nio.file.Paths; -import java.time.Duration; import java.util.ArrayList; import java.util.HashMap; import java.util.HashSet; @@ -96,9 +87,12 @@ import java.util.function.Function; import java.util.stream.Collectors; +import static com.mongodb.AuthenticationMechanism.MONGODB_OIDC; import static com.mongodb.ClusterFixture.getMultiMongosConnectionString; import static com.mongodb.ClusterFixture.isLoadBalanced; import static com.mongodb.ClusterFixture.isSharded; +import static com.mongodb.assertions.Assertions.assertNotNull; +import static com.mongodb.assertions.Assertions.notNull; import static com.mongodb.client.Fixture.getMongoClientSettingsBuilder; import static com.mongodb.client.Fixture.getMultiMongosMongoClientSettingsBuilder; import static com.mongodb.client.unified.EventMatcher.getReasonString; @@ -529,29 +523,39 @@ private void initClient(final BsonDocument entity, final String id, ServerMonitoringModeUtil.fromString(value.asString().getValue()))); break; case "authMechanism": - if (value.equals(new BsonString(AuthenticationMechanism.MONGODB_OIDC.getMechanismName()))) { - clientSettingsBuilder.credential(MongoCredential.createOidcCredential(null)); + if (value.equals(new BsonString(MONGODB_OIDC.getMechanismName()))) { + // authMechanismProperties depends on authMechanism + BsonDocument authMechanismProperties = entity + .getDocument("uriOptions") + .getDocument("authMechanismProperties"); + boolean hasPlaceholder = authMechanismProperties.equals( + new BsonDocument("$$placeholder", new BsonInt32(1))); + if (!hasPlaceholder) { + throw new UnsupportedOperationException( + "Unsupported authMechanismProperties for authMechanism: " + value); + } + + String env = assertNotNull(getenv("OIDC_ENV")); + MongoCredential oidcCredential = MongoCredential + .createOidcCredential(null) + .withMechanismProperty("ENVIRONMENT", env); + if (env.equals("azure")) { + oidcCredential = oidcCredential.withMechanismProperty( + MongoCredential.TOKEN_RESOURCE_KEY, getenv("AZUREOIDC_RESOURCE")); + } else if (env.equals("gcp")) { + oidcCredential = oidcCredential.withMechanismProperty( + MongoCredential.TOKEN_RESOURCE_KEY, getenv("GCPOIDC_RESOURCE")); + } + clientSettingsBuilder.credential(oidcCredential); break; } throw new UnsupportedOperationException("Unsupported authMechanism: " + value); case "authMechanismProperties": - MongoCredential credential = clientSettingsBuilder.build().getCredential(); - boolean isOidc = credential != null - && credential.getAuthenticationMechanism() == AuthenticationMechanism.MONGODB_OIDC; - boolean hasPlaceholder = value.equals(new BsonDocument("$$placeholder", new BsonInt32(1))); - if (isOidc && hasPlaceholder) { - clientSettingsBuilder.credential(credential.withMechanismProperty( - MongoCredential.OIDC_CALLBACK_KEY, - (MongoCredential.OidcCallback) context -> { - Path path = Paths.get(getenv(OidcAuthenticator.AWS_WEB_IDENTITY_TOKEN_FILE)); - String accessToken; - try { - accessToken = new String(Files.readAllBytes(path), StandardCharsets.UTF_8); - } catch (IOException e) { - throw new RuntimeException(e); - } - return new MongoCredential.OidcCallbackResult(accessToken, Duration.ZERO); - })); + // authMechanismProperties are handled as part of authMechanism, above + BsonValue authMechanism = entity + .getDocument("uriOptions") + .get("authMechanism"); + if (authMechanism.equals(new BsonString(MONGODB_OIDC.getMechanismName()))) { break; } throw new UnsupportedOperationException("Failure to apply authMechanismProperties: " + value); @@ -718,7 +722,7 @@ private void initClientEncryption(final BsonDocument entity, final String id, } } - putEntity(id, clientEncryptionSupplier.apply(Assertions.notNull("mongoClient", mongoClient), builder.build()), clientEncryptions); + putEntity(id, clientEncryptionSupplier.apply(notNull("mongoClient", mongoClient), builder.build()), clientEncryptions); } private TransactionOptions getTransactionOptions(final BsonDocument options) { diff --git a/driver-sync/src/test/functional/com/mongodb/client/unified/RunOnRequirementsMatcher.java b/driver-sync/src/test/functional/com/mongodb/client/unified/RunOnRequirementsMatcher.java index aa7a3f80a53..60553c73f96 100644 --- a/driver-sync/src/test/functional/com/mongodb/client/unified/RunOnRequirementsMatcher.java +++ b/driver-sync/src/test/functional/com/mongodb/client/unified/RunOnRequirementsMatcher.java @@ -69,7 +69,10 @@ public static boolean runOnRequirementsMet(final BsonArray runOnRequirements, fi } break; case "auth": - if (curRequirement.getValue().asBoolean().getValue() == (clientSettings.getCredential() == null)) { + boolean authRequired = curRequirement.getValue().asBoolean().getValue(); + boolean credentialPresent = clientSettings.getCredential() != null; + + if (authRequired != credentialPresent) { requirementMet = false; break requirementLoop; } diff --git a/driver-sync/src/test/functional/com/mongodb/internal/connection/OidcAuthenticationProseTests.java b/driver-sync/src/test/functional/com/mongodb/internal/connection/OidcAuthenticationProseTests.java index b5a87a51cef..9915f6a6a34 100644 --- a/driver-sync/src/test/functional/com/mongodb/internal/connection/OidcAuthenticationProseTests.java +++ b/driver-sync/src/test/functional/com/mongodb/internal/connection/OidcAuthenticationProseTests.java @@ -23,6 +23,8 @@ import com.mongodb.MongoCredential; import com.mongodb.MongoSecurityException; import com.mongodb.MongoSocketException; +import com.mongodb.assertions.Assertions; +import com.mongodb.client.Fixture; import com.mongodb.client.MongoClient; import com.mongodb.client.MongoClients; import com.mongodb.client.TestListener; @@ -33,7 +35,7 @@ import org.bson.BsonDocument; import org.bson.BsonInt32; import org.bson.BsonString; -import org.jetbrains.annotations.NotNull; +import org.bson.Document; import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; @@ -60,62 +62,70 @@ import java.util.stream.Collectors; import static com.mongodb.MongoCredential.ALLOWED_HOSTS_KEY; +import static com.mongodb.MongoCredential.ENVIRONMENT_KEY; +import static com.mongodb.MongoCredential.OIDC_CALLBACK_KEY; import static com.mongodb.MongoCredential.OIDC_HUMAN_CALLBACK_KEY; -import static com.mongodb.MongoCredential.OidcCallbackResult; import static com.mongodb.MongoCredential.OidcCallback; import static com.mongodb.MongoCredential.OidcCallbackContext; -import static com.mongodb.MongoCredential.OIDC_CALLBACK_KEY; +import static com.mongodb.MongoCredential.OidcCallbackResult; +import static com.mongodb.MongoCredential.TOKEN_RESOURCE_KEY; +import static com.mongodb.assertions.Assertions.assertNotNull; import static java.lang.System.getenv; import static java.util.Arrays.asList; import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNull; import static org.junit.jupiter.api.Assertions.assertThrows; import static org.junit.jupiter.api.Assertions.assertTrue; -import static org.junit.jupiter.api.Assertions.fail; import static org.junit.jupiter.api.Assumptions.assumeTrue; import static util.ThreadTestHelpers.executeAll; /** * See - * Prose Tests. + * Prose Tests. */ public class OidcAuthenticationProseTests { + private String appName; + public static boolean oidcTestsEnabled() { return Boolean.parseBoolean(getenv().get("OIDC_TESTS_ENABLED")); } - private String appName; + private void assumeTestEnvironment() { + assumeTrue(getenv("OIDC_TOKEN_DIR") != null); + } protected static String getOidcUri() { - ConnectionString cs = new ConnectionString(getenv("OIDC_ATLAS_URI_SINGLE")); - // remove any username and password - return "mongodb+srv://" + cs.getHosts().get(0) + "/?authMechanism=MONGODB-OIDC"; + return getenv("MONGODB_URI_SINGLE"); + } + + private static String getOidcUriMulti() { + return getenv("MONGODB_URI_MULTI"); } - protected static String getOidcUri(final String username) { - ConnectionString cs = new ConnectionString(getenv("OIDC_ATLAS_URI_SINGLE")); - // set username - return "mongodb+srv://" + username + "@" + cs.getHosts().get(0) + "/?authMechanism=MONGODB-OIDC"; + private static String getOidcEnv() { + return getenv("OIDC_ENV"); } - protected static String getOidcUriMulti(@Nullable final String username) { - ConnectionString cs = new ConnectionString(getenv("OIDC_ATLAS_URI_MULTI")); - // set username - String userPart = username == null ? "" : username + "@"; - return "mongodb+srv://" + userPart + cs.getHosts().get(0) + "/?authMechanism=MONGODB-OIDC"; + private static void assumeAzure() { + assumeTrue(getOidcEnv().equals("azure")); } - private static String getAwsOidcUri() { - return getOidcUri() + "&authMechanismProperties=PROVIDER_NAME:aws"; + @Nullable + private static String getUserWithDomain(@Nullable final String user) { + return user == null ? null : user + "@" + getenv("OIDC_DOMAIN"); } - @NotNull private static String oidcTokenDirectory() { - return getenv("OIDC_TOKEN_DIR"); + String dir = getenv("OIDC_TOKEN_DIR"); + if (!dir.endsWith("/")) { + dir = dir + "/"; + } + return dir; } - private static String getAwsTokenFilePath() { - return getenv(OidcAuthenticator.AWS_WEB_IDENTITY_TOKEN_FILE); + private static String getTestTokenFilePath() { + return getenv(OidcAuthenticator.OIDC_TOKEN_FILE); } protected MongoClient createMongoClient(final MongoClientSettings settings) { @@ -137,17 +147,17 @@ public void afterEach() { @Test public void test1p1CallbackIsCalledDuringAuth() { // #. Create a ``MongoClient`` configured with an OIDC callback... - TestCallback onRequest = createCallback(); - MongoClientSettings clientSettings = createSettings(getOidcUri(), onRequest, null); + TestCallback callback = createCallback(); + MongoClientSettings clientSettings = createSettings(callback); // #. Perform a find operation that succeeds performFind(clientSettings); - assertEquals(1, onRequest.invocations.get()); + assertEquals(1, callback.invocations.get()); } @Test public void test1p2CallbackCalledOnceForMultipleConnections() { - TestCallback onRequest = createCallback(); - MongoClientSettings clientSettings = createSettings(getOidcUri(), onRequest, null); + TestCallback callback = createCallback(); + MongoClientSettings clientSettings = createSettings(callback); try (MongoClient mongoClient = createMongoClient(clientSettings)) { List threads = new ArrayList<>(); for (int i = 0; i < 10; i++) { @@ -164,77 +174,71 @@ public void test1p2CallbackCalledOnceForMultipleConnections() { } } } - assertEquals(1, onRequest.invocations.get()); + assertEquals(1, callback.invocations.get()); } @Test public void test2p1ValidCallbackInputs() { - String connectionString = getOidcUri(); Duration expectedSeconds = Duration.ofMinutes(5); - TestCallback onRequest = createCallback(); + TestCallback callback1 = createCallback(); // #. Verify that the request callback was called with the appropriate // inputs, including the timeout parameter if possible. - OidcCallback onRequest2 = (context) -> { + OidcCallback callback2 = (context) -> { assertEquals(expectedSeconds, context.getTimeout()); - return onRequest.onRequest(context); + return callback1.onRequest(context); }; - MongoClientSettings clientSettings = createSettings(connectionString, onRequest2); + MongoClientSettings clientSettings = createSettings(callback2); try (MongoClient mongoClient = createMongoClient(clientSettings)) { performFind(mongoClient); // callback was called - assertEquals(1, onRequest.getInvocations()); + assertEquals(1, callback1.getInvocations()); } } @Test public void test2p2RequestCallbackReturnsNull() { //noinspection ConstantConditions - OidcCallback onRequest = (context) -> null; - MongoClientSettings settings = this.createSettings(getOidcUri(), onRequest, null); - performFind(settings, MongoConfigurationException.class, "Result of callback must not be null"); + OidcCallback callback = (context) -> null; + MongoClientSettings clientSettings = this.createSettings(callback); + assertFindFails(clientSettings, MongoConfigurationException.class, + "Result of callback must not be null"); } @Test public void test2p3CallbackReturnsMissingData() { // #. Create a client with a request callback that returns data not // conforming to the OIDCRequestTokenResult with missing field(s). - OidcCallback onRequest = (context) -> { + OidcCallback callback = (context) -> { //noinspection ConstantConditions - return new OidcCallbackResult(null, Duration.ZERO); + return new OidcCallbackResult(null); }; // we ensure that the error is propagated - MongoClientSettings clientSettings = createSettings(getOidcUri(), onRequest, null); + MongoClientSettings clientSettings = createSettings(callback); try (MongoClient mongoClient = createMongoClient(clientSettings)) { - try { - performFind(mongoClient); - fail(); - } catch (Exception e) { - assertCause(IllegalArgumentException.class, "accessToken can not be null", e); - } + assertCause(IllegalArgumentException.class, + "accessToken can not be null", + () -> performFind(mongoClient)); } } @Test public void test2p4InvalidClientConfigurationWithCallback() { - String awsOidcUri = getAwsOidcUri(); + String uri = getOidcUri() + "&authMechanismProperties=ENVIRONMENT:" + getOidcEnv(); MongoClientSettings settings = createSettings( - awsOidcUri, createCallback(), null); - try { - performFind(settings); - fail(); - } catch (Exception e) { - assertCause(IllegalArgumentException.class, - "OIDC_CALLBACK must not be specified when PROVIDER_NAME is specified", e); - } + uri, createCallback(), null, OIDC_CALLBACK_KEY); + assertCause(IllegalArgumentException.class, + "OIDC_CALLBACK must not be specified when ENVIRONMENT is specified", + () -> performFind(settings)); } @Test public void test3p1AuthFailsWithCachedToken() throws ExecutionException, InterruptedException, NoSuchFieldException, IllegalAccessException { - TestCallback onRequestWrapped = createCallback(); + TestCallback callbackWrapped = createCallback(); + // reference to the token to poison CompletableFuture poisonToken = new CompletableFuture<>(); - OidcCallback onRequest = (context) -> { - OidcCallbackResult result = onRequestWrapped.onRequest(context); + OidcCallback callback = (context) -> { + OidcCallbackResult result = callbackWrapped.onRequest(context); String accessToken = result.getAccessToken(); if (!poisonToken.isDone()) { poisonToken.complete(accessToken); @@ -242,11 +246,11 @@ public void test3p1AuthFailsWithCachedToken() throws ExecutionException, Interru return result; }; - MongoClientSettings clientSettings = createSettings(getOidcUri(), onRequest, null); + MongoClientSettings clientSettings = createSettings(callback); try (MongoClient mongoClient = createMongoClient(clientSettings)) { // populate cache performFind(mongoClient); - assertEquals(1, onRequestWrapped.invocations.get()); + assertEquals(1, callbackWrapped.invocations.get()); // Poison the *Client Cache* with an invalid access token. // uses reflection String poisonString = poisonToken.get(); @@ -256,50 +260,161 @@ public void test3p1AuthFailsWithCachedToken() throws ExecutionException, Interru poisonChars[0] = '~'; poisonChars[1] = '~'; - assertEquals(1, onRequestWrapped.invocations.get()); + assertEquals(1, callbackWrapped.invocations.get()); // cause another connection to be opened - delayNextFind(); // cause both callbacks to be called + delayNextFind(); executeAll(2, () -> performFind(mongoClient)); } - assertEquals(2, onRequestWrapped.invocations.get()); + assertEquals(2, callbackWrapped.invocations.get()); } @Test public void test3p2AuthFailsWithoutCachedToken() { - MongoClientSettings clientSettings = createSettings(getOidcUri(), - (x) -> new OidcCallbackResult("invalid_token", Duration.ZERO), null); + OidcCallback callback = + (x) -> new OidcCallbackResult("invalid_token"); + MongoClientSettings clientSettings = createSettings(callback); try (MongoClient mongoClient = createMongoClient(clientSettings)) { - try { - performFind(mongoClient); - fail(); - } catch (Exception e) { - assertCause(MongoCommandException.class, - "Command failed with error 18 (AuthenticationFailed):", e); - } + assertCause(MongoCommandException.class, + "Command failed with error 18 (AuthenticationFailed):", + () -> performFind(mongoClient)); + } + } + + @Test + public void test3p3UnexpectedErrorDoesNotClearCache() { + assumeTestEnvironment(); + + TestListener listener = new TestListener(); + TestCommandListener commandListener = new TestCommandListener(listener); + + TestCallback callback = createCallback(); + MongoClientSettings clientSettings = createSettings(getOidcUri(), callback, commandListener); + + try (MongoClient mongoClient = createMongoClient(clientSettings)) { + failCommand(20, 1, "saslStart"); + assertCause(MongoCommandException.class, + "Command failed with error 20", + () -> performFind(mongoClient)); + + assertEquals(Arrays.asList( + "isMaster started", + "isMaster succeeded", + "saslStart started", + "saslStart failed" + ), listener.getEventStrings()); + + assertEquals(1, callback.getInvocations()); + performFind(mongoClient); + assertEquals(1, callback.getInvocations()); } } @Test public void test4p1Reauthentication() { - TestCallback onRequest = createCallback(); - MongoClientSettings clientSettings = createSettings(getOidcUri(), onRequest); + TestCallback callback = createCallback(); + MongoClientSettings clientSettings = createSettings(callback); try (MongoClient mongoClient = createMongoClient(clientSettings)) { failCommand(391, 1, "find"); // #. Perform a find operation that succeeds. performFind(mongoClient); } - assertEquals(2, onRequest.invocations.get()); + assertEquals(2, callback.invocations.get()); + } + + @Test + public void test4p2ReadCommandsFailIfReauthenticationFails() { + // Create a `MongoClient` whose OIDC callback returns one good token + // and then bad tokens after the first call. + TestCallback wrappedCallback = createCallback(); + OidcCallback callback = (context) -> { + OidcCallbackResult result1 = wrappedCallback.callback(context); + return new OidcCallbackResult(wrappedCallback.getInvocations() > 1 ? "bad" : result1.getAccessToken()); + }; + MongoClientSettings clientSettings = createSettings(callback); + try (MongoClient mongoClient = createMongoClient(clientSettings)) { + performFind(mongoClient); + failCommand(391, 1, "find"); + assertCause(MongoCommandException.class, + "Command failed with error 18", + () -> performFind(mongoClient)); + } + assertEquals(2, wrappedCallback.invocations.get()); + } + + @Test + public void test4p3WriteCommandsFailIfReauthenticationFails() { + // Create a `MongoClient` whose OIDC callback returns one good token + // and then bad tokens after the first call. + TestCallback wrappedCallback = createCallback(); + OidcCallback callback = (context) -> { + OidcCallbackResult result1 = wrappedCallback.callback(context); + return new OidcCallbackResult( + wrappedCallback.getInvocations() > 1 ? "bad" : result1.getAccessToken()); + }; + MongoClientSettings clientSettings = createSettings(callback); + try (MongoClient mongoClient = createMongoClient(clientSettings)) { + performInsert(mongoClient); + failCommand(391, 1, "insert"); + assertCause(MongoCommandException.class, + "Command failed with error 18", + () -> performInsert(mongoClient)); + } + assertEquals(2, wrappedCallback.invocations.get()); + } + + private static void performInsert(final MongoClient mongoClient) { + mongoClient + .getDatabase("test") + .getCollection("test") + .insertOne(Document.parse("{ x: 1 }")); + } + + @Test + public void test5p1AzureSucceedsWithNoUsername() { + assumeAzure(); + String oidcUri = getOidcUri(); + MongoClientSettings clientSettings = createSettings(oidcUri, createCallback(), null); + // Create an OIDC configured client with `ENVIRONMENT:azure` and a valid + // `TOKEN_RESOURCE` and no username. + MongoCredential credential = Assertions.assertNotNull(clientSettings.getCredential()); + assertNotNull(credential.getMechanismProperty(TOKEN_RESOURCE_KEY, null)); + assertNull(credential.getUserName()); + try (MongoClient mongoClient = createMongoClient(clientSettings)) { + // Perform a `find` operation that succeeds. + performFind(mongoClient); + } + } + + @Test + public void test5p2AzureFailsWithBadUsername() { + assumeAzure(); + String oidcUri = getOidcUri(); + ConnectionString cs = new ConnectionString(oidcUri); + MongoCredential oldCredential = Assertions.assertNotNull(cs.getCredential()); + String tokenResource = oldCredential.getMechanismProperty(TOKEN_RESOURCE_KEY, null); + assertNotNull(tokenResource); + MongoCredential cred = MongoCredential.createOidcCredential("bad") + .withMechanismProperty(ENVIRONMENT_KEY, "azure") + .withMechanismProperty(TOKEN_RESOURCE_KEY, tokenResource); + MongoClientSettings.Builder builder = MongoClientSettings.builder() + .applicationName(appName) + .retryReads(false) + .applyConnectionString(cs) + .credential(cred); + MongoClientSettings clientSettings = builder.build(); + // the failure is external to the driver + assertFindFails(clientSettings, IOException.class, "400 Bad Request"); } // Tests for human authentication ("testh", to preserve ordering) @Test public void testh1p1SinglePrincipalImplicitUsername() { + assumeTestEnvironment(); // #. Create default OIDC client with authMechanism=MONGODB-OIDC. - String oidcUri = getOidcUri(); TestCallback callback = createHumanCallback(); - MongoClientSettings clientSettings = createHumanSettings(oidcUri, callback, null); + MongoClientSettings clientSettings = createHumanSettings(callback, null); // #. Perform a find operation that succeeds performFind(clientSettings); assertEquals(1, callback.invocations.get()); @@ -307,67 +422,61 @@ public void testh1p1SinglePrincipalImplicitUsername() { @Test public void testh1p2SinglePrincipalExplicitUsername() { + assumeTestEnvironment(); // #. Create a client with MONGODB_URI_SINGLE, a username of test_user1, // authMechanism=MONGODB-OIDC, and the OIDC human callback. - String oidcUri = getOidcUri("test_user1"); TestCallback callback = createHumanCallback(); - MongoClientSettings clientSettings = createHumanSettings(oidcUri, callback, null); + MongoClientSettings clientSettings = createSettingsHuman(getUserWithDomain("test_user1"), callback, getOidcUri()); // #. Perform a find operation that succeeds performFind(clientSettings); } @Test public void testh1p3MultiplePrincipalUser1() { + assumeTestEnvironment(); // #. Create a client with MONGODB_URI_MULTI, a username of test_user1, // authMechanism=MONGODB-OIDC, and the OIDC human callback. - String oidcUri = getOidcUriMulti("test_user1"); - TestCallback callback = createHumanCallback(); - MongoClientSettings clientSettings = createHumanSettings(oidcUri, callback, null); + MongoClientSettings clientSettings = createSettingsMulti(getUserWithDomain("test_user1"), createHumanCallback()); // #. Perform a find operation that succeeds performFind(clientSettings); } @Test public void testh1p4MultiplePrincipalUser2() { + assumeTestEnvironment(); //- Create a human callback that reads in the generated ``test_user2`` token file. //- Create a client with ``MONGODB_URI_MULTI``, a username of ``test_user2``, // ``authMechanism=MONGODB-OIDC``, and the OIDC human callback. - String oidcUri = getOidcUriMulti("test_user2"); - TestCallback callback = createHumanCallback() - .setPathSupplier(() -> tokenQueue("test_user2").remove()); - MongoClientSettings clientSettings = createHumanSettings(oidcUri, callback, null); - // #. Perform a find operation that succeeds + MongoClientSettings clientSettings = createSettingsMulti(getUserWithDomain("test_user2"), createHumanCallback() + .setPathSupplier(() -> tokenQueue("test_user2").remove())); performFind(clientSettings); } @Test public void testh1p5MultiplePrincipalNoUser() { - //- Create a client with ``MONGODB_URI_MULTI``, no username, - // ``authMechanism=MONGODB-OIDC``, and the OIDC human callback. - String oidcUri = getOidcUriMulti(null); - TestCallback callback = createHumanCallback(); - MongoClientSettings clientSettings = createHumanSettings(oidcUri, callback, null); - // #. Perform a find operation that succeeds - performFind(clientSettings, MongoCommandException.class, "Authentication failed"); + assumeTestEnvironment(); + // Create an OIDC configured client with `MONGODB_URI_MULTI` and no username. + MongoClientSettings clientSettings = createSettingsMulti(null, createHumanCallback()); + // Assert that a `find` operation fails. + assertFindFails(clientSettings, MongoCommandException.class, "Authentication failed"); } @Test public void testh1p6AllowedHostsBlocked() { + assumeTestEnvironment(); //- Create a default OIDC client, with an ``ALLOWED_HOSTS`` that is an empty list. //- Assert that a ``find`` operation fails with a client-side error. - MongoClientSettings settings1 = createSettings( - getOidcUri(), + MongoClientSettings clientSettings1 = createSettings(getOidcUri(), createHumanCallback(), null, OIDC_HUMAN_CALLBACK_KEY, Collections.emptyList()); - performFind(settings1, MongoSecurityException.class, "not permitted by ALLOWED_HOSTS"); + assertFindFails(clientSettings1, MongoSecurityException.class, "not permitted by ALLOWED_HOSTS"); //- Create a client that uses the URL // ``mongodb://localhost/?authMechanism=MONGODB-OIDC&ignored=example.com``, a // human callback, and an ``ALLOWED_HOSTS`` that contains ``["example.com"]``. //- Assert that a ``find`` operation fails with a client-side error. - MongoClientSettings settings2 = createSettings( - getOidcUri() + "&ignored=example.com", + MongoClientSettings clientSettings2 = createSettings(getOidcUri() + "&ignored=example.com", createHumanCallback(), null, OIDC_HUMAN_CALLBACK_KEY, Arrays.asList("example.com")); - performFind(settings2, MongoSecurityException.class, "not permitted by ALLOWED_HOSTS"); + assertFindFails(clientSettings2, MongoSecurityException.class, "not permitted by ALLOWED_HOSTS"); } // Not a prose test @@ -379,68 +488,119 @@ public void testAllowedHostsDisallowedInConnectionString() { () -> new ConnectionString(string)); } + @Test + public void testh1p7AllowedHostsInConnectionStringIgnored() { + // example.com changed to localhost, because resolveAdditionalQueryParametersFromTxtRecords + // fails with "Failed looking up TXT record for host example.com" + String string = "mongodb+srv://localhost/?authMechanism=MONGODB-OIDC&authMechanismProperties=ALLOWED_HOSTS:%5B%22localhost%22%5D"; + assertCause(IllegalArgumentException.class, + "connection string contains disallowed mechanism properties", + () -> new ConnectionString(string)); + } + + @Test + public void testh1p8MachineIdpWithHumanCallback() { + assumeTrue(getenv("OIDC_IS_LOCAL") != null); + + TestCallback callback = createHumanCallback() + .setPathSupplier(() -> oidcTokenDirectory() + "test_machine"); + MongoClientSettings clientSettings = createSettingsHuman( + "test_machine", callback, getOidcUri()); + performFind(clientSettings); + } + @Test public void testh2p1ValidCallbackInputs() { - TestCallback onRequest = createHumanCallback(); - OidcCallback onRequest2 = (context) -> { - assertTrue(context.getIdpInfo().getClientId().startsWith("0oad")); - assertTrue(context.getIdpInfo().getIssuer().endsWith("mock-identity-config-oidc")); - assertEquals(Arrays.asList("fizz", "buzz"), context.getIdpInfo().getRequestScopes()); + assumeTestEnvironment(); + TestCallback callback1 = createHumanCallback(); + OidcCallback callback2 = (context) -> { + MongoCredential.IdpInfo idpInfo = assertNotNull(context.getIdpInfo()); + assertTrue(assertNotNull(idpInfo.getClientId()).startsWith("0oad")); + assertTrue(idpInfo.getIssuer().endsWith("mock-identity-config-oidc")); + assertEquals(Arrays.asList("fizz", "buzz"), idpInfo.getRequestScopes()); assertEquals(Duration.ofMinutes(5), context.getTimeout()); - return onRequest.onRequest(context); + return callback1.onRequest(context); }; - MongoClientSettings clientSettings = createHumanSettings(getOidcUri(), onRequest2, null); + MongoClientSettings clientSettings = createHumanSettings(callback2, null); try (MongoClient mongoClient = createMongoClient(clientSettings)) { performFind(mongoClient); // Ensure that callback was called - assertEquals(1, onRequest.getInvocations()); + assertEquals(1, callback1.getInvocations()); } } @Test public void testh2p2HumanCallbackReturnsMissingData() { + assumeTestEnvironment(); //noinspection ConstantConditions - OidcCallback onRequestNull = (context) -> null; - performFind(createHumanSettings(getOidcUri(), onRequestNull, null), + OidcCallback callbackNull = (context) -> null; + assertFindFails(createHumanSettings(callbackNull, null), MongoConfigurationException.class, "Result of callback must not be null"); //noinspection ConstantConditions - OidcCallback onRequest = (context) -> new OidcCallbackResult(null, Duration.ZERO); - performFind(createHumanSettings(getOidcUri(), onRequest, null), + OidcCallback callback = + (context) -> new OidcCallbackResult(null); + assertFindFails(createHumanSettings(callback, null), IllegalArgumentException.class, "accessToken can not be null"); + } + // not a prose test + @Test + public void testRefreshTokenAbsent() { // additionally, check validation for refresh in machine workflow: - OidcCallback onRequestMachineRefresh = (context) -> new OidcCallbackResult("access", Duration.ZERO, "exists"); - performFind(createSettings(getOidcUri(), onRequestMachineRefresh, null), + OidcCallback callbackMachineRefresh = + (context) -> new OidcCallbackResult("access", Duration.ZERO, "exists"); + assertFindFails(createSettings(callbackMachineRefresh), MongoConfigurationException.class, "Refresh token must only be provided in human workflow"); } @Test - public void testh3p1UsesSpecAuthIfCachedToken() { - failCommandAndCloseConnection("find", 1); - MongoClientSettings settings = createHumanSettings(getOidcUri(), createHumanCallback(), null); + public void testh2p3RefreshTokenPassed() { + assumeTestEnvironment(); + AtomicInteger refreshTokensProvided = new AtomicInteger(); + TestCallback callback1 = createHumanCallback(); + OidcCallback callback2 = (context) -> { + if (context.getRefreshToken() != null) { + refreshTokensProvided.incrementAndGet(); + } + return callback1.onRequest(context); + }; + MongoClientSettings clientSettings = createHumanSettings(callback2, null); + try (MongoClient mongoClient = createMongoClient(clientSettings)) { + performFind(mongoClient); + failCommand(391, 1, "find"); + performFind(mongoClient); + assertEquals(2, callback1.getInvocations()); + assertEquals(1, refreshTokensProvided.get()); + } + } - try (MongoClient mongoClient = createMongoClient(settings)) { + @Test + public void testh3p1UsesSpecAuthIfCachedToken() { + assumeTestEnvironment(); + MongoClientSettings clientSettings = createHumanSettings(createHumanCallback(), null); + try (MongoClient mongoClient = createMongoClient(clientSettings)) { + failCommandAndCloseConnection("find", 1); assertCause(MongoSocketException.class, "Prematurely reached end of stream", () -> performFind(mongoClient)); - failCommand(20, 99, "saslStart"); - + failCommand(18, 1, "saslStart"); performFind(mongoClient); } } @Test public void testh3p2NoSpecAuthIfNoCachedToken() { - failCommand(20, 99, "saslStart"); + assumeTestEnvironment(); + failCommand(18, 1, "saslStart"); TestListener listener = new TestListener(); TestCommandListener commandListener = new TestCommandListener(listener); - performFind(createHumanSettings(getOidcUri(), createHumanCallback(), commandListener), + assertFindFails(createHumanSettings(createHumanCallback(), commandListener), MongoCommandException.class, - "Command failed with error 20"); + "Command failed with error 18"); assertEquals(Arrays.asList( "isMaster started", "isMaster succeeded", @@ -451,18 +611,19 @@ public void testh3p2NoSpecAuthIfNoCachedToken() { } @Test - public void testh4p1Succeeds() { + public void testh4p1ReauthenticationSucceeds() { + assumeTestEnvironment(); TestListener listener = new TestListener(); TestCommandListener commandListener = new TestCommandListener(listener); TestCallback callback = createHumanCallback() .setEventListener(listener); - MongoClientSettings settings = createHumanSettings(getOidcUri(), callback, commandListener); - try (MongoClient mongoClient = createMongoClient(settings)) { + MongoClientSettings clientSettings = createHumanSettings(callback, commandListener); + try (MongoClient mongoClient = createMongoClient(clientSettings)) { performFind(mongoClient); listener.clear(); assertEquals(1, callback.getInvocations()); - failCommand(391, 1, "find"); + // Perform another find operation that succeeds. performFind(mongoClient); assertEquals(Arrays.asList( // first find fails: @@ -482,27 +643,66 @@ public void testh4p1Succeeds() { @Test public void testh4p2SucceedsNoRefresh() { - TestListener listener = new TestListener(); - TestCommandListener commandListener = new TestCommandListener(listener); - TestCallback callback = createHumanCallback().setEventListener(listener); - MongoClientSettings settings = createHumanSettings(getOidcUri(), callback, commandListener); - try (MongoClient mongoClient = createMongoClient(settings)) { - + assumeTestEnvironment(); + TestCallback callback = createHumanCallback(); + MongoClientSettings clientSettings = createHumanSettings(callback, null); + try (MongoClient mongoClient = createMongoClient(clientSettings)) { performFind(mongoClient); - listener.clear(); assertEquals(1, callback.getInvocations()); failCommand(391, 1, "find"); performFind(mongoClient); + assertEquals(2, callback.getInvocations()); } } - // TODO-OIDC awaiting spec updates, add 4.3 and 4.4 + @Test + public void testh4p3SucceedsAfterRefreshFails() { + assumeTestEnvironment(); + TestCallback callback1 = createHumanCallback(); + OidcCallback callback2 = (context) -> { + OidcCallbackResult oidcCallbackResult = callback1.onRequest(context); + return new OidcCallbackResult(oidcCallbackResult.getAccessToken(), Duration.ofMinutes(5), "BAD_REFRESH"); + }; + MongoClientSettings clientSettings = createHumanSettings(callback2, null); + try (MongoClient mongoClient = createMongoClient(clientSettings)) { + performFind(mongoClient); + failCommand(391, 1, "find"); + performFind(mongoClient); + assertEquals(2, callback1.getInvocations()); + } + } + + @Test + public void testh4p4Fails() { + assumeTestEnvironment(); + ConcurrentLinkedQueue tokens = tokenQueue( + "test_user1", + "test_user1_expires", + "test_user1_expires"); + TestCallback callback1 = createHumanCallback() + .setPathSupplier(() -> tokens.remove()); + OidcCallback callback2 = (context) -> { + OidcCallbackResult oidcCallbackResult = callback1.onRequest(context); + return new OidcCallbackResult(oidcCallbackResult.getAccessToken(), Duration.ofMinutes(5), "BAD_REFRESH"); + }; + MongoClientSettings clientSettings = createHumanSettings(callback2, null); + try (MongoClient mongoClient = createMongoClient(clientSettings)) { + performFind(mongoClient); + assertEquals(1, callback1.getInvocations()); + failCommand(391, 1, "find"); + assertCause(MongoCommandException.class, + "Command failed with error 18", + () -> performFind(mongoClient)); + assertEquals(3, callback1.getInvocations()); + } + } // Not a prose test @Test public void testErrorClearsCache() { + assumeTestEnvironment(); // #. Create a new client with a valid request callback that // gives credentials that expire within 5 minutes and // a refresh callback that gives invalid credentials. @@ -512,14 +712,14 @@ public void testErrorClearsCache() { "test_user1_expires", "test_user1_expires", "test_user1_1"); - TestCallback onRequest = createHumanCallback() + TestCallback callback = createHumanCallback() .setRefreshToken("refresh") .setPathSupplier(() -> tokens.remove()) .setEventListener(listener); TestCommandListener commandListener = new TestCommandListener(listener); - MongoClientSettings clientSettings = createHumanSettings(getOidcUri(), onRequest, commandListener); + MongoClientSettings clientSettings = createHumanSettings(callback, commandListener); try (MongoClient mongoClient = createMongoClient(clientSettings)) { // #. Ensure that a find operation adds a new entry to the cache. performFind(mongoClient); @@ -585,17 +785,31 @@ public void testErrorClearsCache() { } } + + private MongoClientSettings createSettings(final OidcCallback callback) { + return createSettings(getOidcUri(), callback, null); + } + public MongoClientSettings createSettings( final String connectionString, - @Nullable final OidcCallback onRequest) { - return createSettings(connectionString, onRequest, null); + @Nullable final TestCallback callback) { + return createSettings(connectionString, callback, null); } private MongoClientSettings createSettings( final String connectionString, @Nullable final OidcCallback callback, @Nullable final CommandListener commandListener) { - return createSettings(connectionString, callback, commandListener, OIDC_CALLBACK_KEY); + String cleanedConnectionString = callback == null ? connectionString : connectionString + .replace("ENVIRONMENT:azure,", "") + .replace("ENVIRONMENT:gcp,", "") + .replace("ENVIRONMENT:test,", ""); + return createSettings(cleanedConnectionString, callback, commandListener, OIDC_CALLBACK_KEY); + } + + private MongoClientSettings createHumanSettings( + final OidcCallback callback, @Nullable final TestCommandListener commandListener) { + return createHumanSettings(getOidcUri(), callback, commandListener); } private MongoClientSettings createHumanSettings( @@ -605,15 +819,16 @@ private MongoClientSettings createHumanSettings( return createSettings(connectionString, callback, commandListener, OIDC_HUMAN_CALLBACK_KEY); } - @NotNull private MongoClientSettings createSettings( final String connectionString, - @Nullable final OidcCallback onRequest, + final @Nullable OidcCallback callback, @Nullable final CommandListener commandListener, final String oidcCallbackKey) { ConnectionString cs = new ConnectionString(connectionString); - MongoCredential credential = cs.getCredential() - .withMechanismProperty(oidcCallbackKey, onRequest); + MongoCredential credential = assertNotNull(cs.getCredential()); + if (callback != null) { + credential = credential.withMechanismProperty(oidcCallbackKey, callback); + } MongoClientSettings.Builder builder = MongoClientSettings.builder() .applicationName(appName) .applyConnectionString(cs) @@ -627,13 +842,13 @@ private MongoClientSettings createSettings( private MongoClientSettings createSettings( final String connectionString, - @Nullable final OidcCallback onRequest, + @Nullable final OidcCallback callback, @Nullable final CommandListener commandListener, final String oidcCallbackKey, @Nullable final List allowedHosts) { ConnectionString cs = new ConnectionString(connectionString); MongoCredential credential = cs.getCredential() - .withMechanismProperty(oidcCallbackKey, onRequest) + .withMechanismProperty(oidcCallbackKey, callback) .withMechanismProperty(ALLOWED_HOSTS_KEY, allowedHosts); MongoClientSettings.Builder builder = MongoClientSettings.builder() .applicationName(appName) @@ -645,13 +860,29 @@ private MongoClientSettings createSettings( return builder.build(); } + private MongoClientSettings createSettingsMulti(@Nullable final String user, final OidcCallback callback) { + return createSettingsHuman(user, callback, getOidcUriMulti()); + } + + private MongoClientSettings createSettingsHuman(@Nullable final String user, final OidcCallback callback, final String oidcUri) { + ConnectionString cs = new ConnectionString(oidcUri); + MongoCredential credential = MongoCredential.createOidcCredential(user) + .withMechanismProperty(OIDC_HUMAN_CALLBACK_KEY, callback); + return MongoClientSettings.builder() + .applicationName(appName) + .applyConnectionString(cs) + .retryReads(false) + .credential(credential) + .build(); + } + private void performFind(final MongoClientSettings settings) { try (MongoClient mongoClient = createMongoClient(settings)) { performFind(mongoClient); } } - private void performFind( + private void assertFindFails( final MongoClientSettings settings, final Class expectedExceptionOrCause, final String expectedMessage) { @@ -670,27 +901,21 @@ private void performFind(final MongoClient mongoClient) { private static void assertCause( final Class expectedCause, final String expectedMessageFragment, final Executable e) { - Throwable actualException = assertThrows(Throwable.class, e); - assertCause(expectedCause, expectedMessageFragment, actualException); - } - - private static void assertCause( - final Class expectedCause, final String expectedMessageFragment, final Throwable actualException) { - Throwable cause = actualException; + Throwable cause = assertThrows(Throwable.class, e); while (cause.getCause() != null) { cause = cause.getCause(); } - if (!expectedCause.isInstance(cause)) { - throw new AssertionFailedError("Unexpected cause", actualException); - } if (!cause.getMessage().contains(expectedMessageFragment)) { - throw new AssertionFailedError("Unexpected message", actualException); + throw new AssertionFailedError("Unexpected message: " + cause.getMessage(), cause); + } + if (!expectedCause.isInstance(cause)) { + throw new AssertionFailedError("Unexpected cause: " + cause.getClass(), assertThrows(Throwable.class, e)); } } protected void delayNextFind() { - try (MongoClient client = createMongoClient(createSettings( - getAwsOidcUri(), null, null))) { + + try (MongoClient client = createMongoClient(Fixture.getMongoClientSettings())) { BsonDocument failPointDocument = new BsonDocument("configureFailPoint", new BsonString("failCommand")) .append("mode", new BsonDocument("times", new BsonInt32(1))) .append("data", new BsonDocument() @@ -703,8 +928,7 @@ protected void delayNextFind() { } protected void failCommand(final int code, final int times, final String... commands) { - try (MongoClient mongoClient = createMongoClient(createSettings( - getAwsOidcUri(), null, null))) { + try (MongoClient mongoClient = createMongoClient(Fixture.getMongoClientSettings())) { List list = Arrays.stream(commands).map(c -> new BsonString(c)).collect(Collectors.toList()); BsonDocument failPointDocument = new BsonDocument("configureFailPoint", new BsonString("failCommand")) .append("mode", new BsonDocument("times", new BsonInt32(times))) @@ -717,8 +941,7 @@ protected void failCommand(final int code, final int times, final String... comm } private void failCommandAndCloseConnection(final String command, final int times) { - try (MongoClient mongoClient = createMongoClient(createSettings( - getAwsOidcUri(), null, null))) { + try (MongoClient mongoClient = createMongoClient(Fixture.getMongoClientSettings())) { BsonDocument failPointDocument = new BsonDocument("configureFailPoint", new BsonString("failCommand")) .append("mode", new BsonDocument("times", new BsonInt32(times))) .append("data", new BsonDocument() @@ -772,11 +995,10 @@ public OidcCallbackResult onRequest(final OidcCallbackContext context) { + " - IdpInfo: " + (context.getIdpInfo() == null ? "none" : "present") + ")"); } - return callback(); + return callback(context); } - @NotNull - private OidcCallbackResult callback() { + private OidcCallbackResult callback(final OidcCallbackContext context) { if (concurrentTracker != null) { if (concurrentTracker.get() > 0) { throw new RuntimeException("Callbacks should not be invoked by multiple threads."); @@ -785,20 +1007,23 @@ private OidcCallbackResult callback() { } try { invocations.incrementAndGet(); - Path path = Paths.get(pathSupplier == null - ? getAwsTokenFilePath() - : pathSupplier.get()); - String accessToken; try { simulateDelay(); - accessToken = new String(Files.readAllBytes(path), StandardCharsets.UTF_8); - } catch (IOException | InterruptedException e) { + } catch (InterruptedException e) { throw new RuntimeException(e); } - if (testListener != null) { - testListener.add("read access token: " + path.getFileName()); + MongoCredential credential = assertNotNull(new ConnectionString(getOidcUri()).getCredential()); + String oidcEnv = getOidcEnv(); + OidcCallback c; + if (oidcEnv.contains("azure")) { + c = OidcAuthenticator.getAzureCallback(credential); + } else if (oidcEnv.contains("gcp")) { + c = OidcAuthenticator.getGcpCallback(credential); + } else { + c = getProseTestCallback(); } - return new OidcCallbackResult(accessToken, Duration.ZERO, refreshToken); + return c.onRequest(context); + } finally { if (concurrentTracker != null) { concurrentTracker.decrementAndGet(); @@ -806,6 +1031,23 @@ private OidcCallbackResult callback() { } } + private OidcCallback getProseTestCallback() { + return (x) -> { + try { + Path path = Paths.get(pathSupplier == null + ? getTestTokenFilePath() + : pathSupplier.get()); + String accessToken = new String(Files.readAllBytes(path), StandardCharsets.UTF_8); + if (testListener != null) { + testListener.add("read access token: " + path.getFileName()); + } + return new OidcCallbackResult(accessToken, Duration.ZERO, refreshToken); + } catch (IOException e) { + throw new RuntimeException(e); + } + }; + } + private void simulateDelay() throws InterruptedException { if (delayInMilliseconds != null) { Thread.sleep(delayInMilliseconds); @@ -847,6 +1089,7 @@ public TestCallback setPathSupplier(final Supplier pathSupplier) { this.testListener, pathSupplier); } + public TestCallback setRefreshToken(final String token) { return new TestCallback( token, @@ -857,7 +1100,6 @@ public TestCallback setRefreshToken(final String token) { } } - @NotNull private ConcurrentLinkedQueue tokenQueue(final String... queue) { String tokenPath = oidcTokenDirectory(); return java.util.stream.Stream From 839200636919c856ac40ffc75a003af2935453e6 Mon Sep 17 00:00:00 2001 From: Maxim Katcharov Date: Mon, 29 Apr 2024 18:23:59 -0600 Subject: [PATCH 6/6] Doc fix --- driver-core/src/main/com/mongodb/ConnectionString.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/driver-core/src/main/com/mongodb/ConnectionString.java b/driver-core/src/main/com/mongodb/ConnectionString.java index ae795a65bba..34378d4069f 100644 --- a/driver-core/src/main/com/mongodb/ConnectionString.java +++ b/driver-core/src/main/com/mongodb/ConnectionString.java @@ -241,7 +241,7 @@ *
  • *
  • {@code authMechanismProperties=PROPERTY_NAME:PROPERTY_VALUE,PROPERTY_NAME2:PROPERTY_VALUE2}: This option allows authentication * mechanism properties to be set on the connection string. Property values must be percent-encoded individually, when - * separator or escape characters are used (including {@code ,} (comma), {@code =}, {@code +}, {@code &}, and {@code %}). The + * special characters are used, including {@code ,} (comma), {@code =}, {@code +}, {@code &}, and {@code %}. The * entire substring following the {@code =} should not itself be encoded. *
  • *
  • {@code gssapiServiceName=string}: This option only applies to the GSSAPI mechanism and is used to alter the service name.