diff --git a/.evergreen/prepare-oidc-get-tokens-docker.sh b/.evergreen/prepare-oidc-get-tokens-docker.sh new file mode 100755 index 00000000000..e904d5d2b89 --- /dev/null +++ b/.evergreen/prepare-oidc-get-tokens-docker.sh @@ -0,0 +1,50 @@ +#!/bin/bash + +set -o xtrace +set -o errexit # Exit the script with error if any of the commands fail + +############################################ +# Main Program # +############################################ + +# Supported/used environment variables: +# DRIVERS_TOOLS The path to evergreeen tools +# OIDC_AWS_* Required OIDC_AWS_* env variables must be configured +# +# Environment variables used as output: +# OIDC_TESTS_ENABLED Allows running OIDC tests +# OIDC_TOKEN_DIR The path to generated OIDC AWS tokens +# AWS_WEB_IDENTITY_TOKEN_FILE The path to AWS token for device workflow + +if [ -z ${DRIVERS_TOOLS+x} ]; then + echo "DRIVERS_TOOLS. is not set"; + exit 1 +fi + +if [ -z ${OIDC_AWS_ROLE_ARN+x} ]; then + echo "OIDC_AWS_ROLE_ARN. is not set"; + exit 1 +fi + +if [ -z ${OIDC_AWS_SECRET_ACCESS_KEY+x} ]; then + echo "OIDC_AWS_SECRET_ACCESS_KEY. is not set"; + exit 1 +fi + +if [ -z ${OIDC_AWS_ACCESS_KEY_ID+x} ]; then + echo "OIDC_AWS_ACCESS_KEY_ID. is not set"; + exit 1 +fi + +export AWS_ROLE_ARN=${OIDC_AWS_ROLE_ARN} +export AWS_SECRET_ACCESS_KEY=${OIDC_AWS_SECRET_ACCESS_KEY} +export AWS_ACCESS_KEY_ID=${OIDC_AWS_ACCESS_KEY_ID} +export OIDC_FOLDER=${DRIVERS_TOOLS}/.evergreen/auth_oidc +export OIDC_TOKEN_DIR=${OIDC_FOLDER}/test_tokens +export AWS_WEB_IDENTITY_TOKEN_FILE=${OIDC_TOKEN_DIR}/test1 +export OIDC_TESTS_ENABLED=true + +echo "Configuring OIDC server for local authentication tests" + +cd ${OIDC_FOLDER} +DRIVERS_TOOLS=${DRIVERS_TOOLS} ./oidc_get_tokens.sh \ No newline at end of file diff --git a/.evergreen/prepare-oidc-server-docker.sh b/.evergreen/prepare-oidc-server-docker.sh new file mode 100755 index 00000000000..0fcd1ed4194 --- /dev/null +++ b/.evergreen/prepare-oidc-server-docker.sh @@ -0,0 +1,50 @@ +#!/bin/bash + +set -o xtrace +set -o errexit # Exit the script with error if any of the commands fail + +############################################ +# Main Program # +############################################ + +# Supported/used environment variables: +# DRIVERS_TOOLS The path to evergreeen tools +# OIDC_AWS_* OIDC_AWS_* env variables must be configured +# +# Environment variables used as output: +# OIDC_TESTS_ENABLED Allows running OIDC tests +# OIDC_TOKEN_DIR The path to generated tokens +# AWS_WEB_IDENTITY_TOKEN_FILE The path to AWS token for device workflow + +if [ -z ${DRIVERS_TOOLS+x} ]; then + echo "DRIVERS_TOOLS. is not set"; + exit 1 +fi + +if [ -z ${OIDC_AWS_ROLE_ARN+x} ]; then + echo "OIDC_AWS_ROLE_ARN. is not set"; + exit 1 +fi + +if [ -z ${OIDC_AWS_SECRET_ACCESS_KEY+x} ]; then + echo "OIDC_AWS_SECRET_ACCESS_KEY. is not set"; + exit 1 +fi + +if [ -z ${OIDC_AWS_ACCESS_KEY_ID+x} ]; then + echo "OIDC_AWS_ACCESS_KEY_ID. is not set"; + exit 1 +fi + +export AWS_ROLE_ARN=${OIDC_AWS_ROLE_ARN} +export AWS_SECRET_ACCESS_KEY=${OIDC_AWS_SECRET_ACCESS_KEY} +export AWS_ACCESS_KEY_ID=${OIDC_AWS_ACCESS_KEY_ID} +export OIDC_FOLDER=${DRIVERS_TOOLS}/.evergreen/auth_oidc +export OIDC_TOKEN_DIR=${OIDC_FOLDER}/test_tokens +export AWS_WEB_IDENTITY_TOKEN_FILE=${OIDC_TOKEN_DIR}/test1 +export OIDC_TESTS_ENABLED=true + +echo "Configuring OIDC server for local authentication tests" + +cd ${OIDC_FOLDER} +DRIVERS_TOOLS=${DRIVERS_TOOLS} ./start_local_server.sh \ No newline at end of file diff --git a/bson/src/test/unit/util/ThreadTestHelpers.java b/bson/src/test/unit/util/ThreadTestHelpers.java index a4767c503f9..e2115da079f 100644 --- a/bson/src/test/unit/util/ThreadTestHelpers.java +++ b/bson/src/test/unit/util/ThreadTestHelpers.java @@ -31,15 +31,19 @@ private ThreadTestHelpers() { } public static void executeAll(final int nThreads, final Runnable c) { + executeAll(Collections.nCopies(nThreads, c).toArray(new Runnable[0])); + } + + public static void executeAll(final Runnable... runnables) { ExecutorService service = null; try { - service = Executors.newFixedThreadPool(nThreads); - CountDownLatch latch = new CountDownLatch(nThreads); + service = Executors.newFixedThreadPool(runnables.length); + CountDownLatch latch = new CountDownLatch(runnables.length); List failures = Collections.synchronizedList(new ArrayList<>()); - for (int i = 0; i < nThreads; i++) { + for (final Runnable runnable : runnables) { service.submit(() -> { try { - c.run(); + runnable.run(); } catch (Throwable e) { failures.add(e); } finally { diff --git a/driver-core/src/main/com/mongodb/AuthenticationMechanism.java b/driver-core/src/main/com/mongodb/AuthenticationMechanism.java index db8a909b79d..7a7b7415ef6 100644 --- a/driver-core/src/main/com/mongodb/AuthenticationMechanism.java +++ b/driver-core/src/main/com/mongodb/AuthenticationMechanism.java @@ -37,6 +37,13 @@ public enum AuthenticationMechanism { */ MONGODB_AWS("MONGODB-AWS"), + /** + * The MONGODB-OIDC mechanism. + * @since 4.10 + * @mongodb.server.release 7.0 + */ + MONGODB_OIDC("MONGODB-OIDC"), + /** * The MongoDB X.509 mechanism. This mechanism is available only with client certificates over SSL. */ diff --git a/driver-core/src/main/com/mongodb/ConnectionString.java b/driver-core/src/main/com/mongodb/ConnectionString.java index cd8f359f35c..902a57bc495 100644 --- a/driver-core/src/main/com/mongodb/ConnectionString.java +++ b/driver-core/src/main/com/mongodb/ConnectionString.java @@ -40,6 +40,7 @@ import java.util.Set; import java.util.concurrent.TimeUnit; +import static com.mongodb.internal.connection.OidcAuthenticator.OidcValidator.validateCreateOidcCredential; import static java.lang.String.format; import static java.util.Arrays.asList; import static java.util.Collections.singletonList; @@ -919,6 +920,10 @@ private MongoCredential createMongoCredentialWithMechanism(final AuthenticationM case MONGODB_AWS: credential = MongoCredential.createAwsCredential(userName, password); break; + case MONGODB_OIDC: + validateCreateOidcCredential(password); + credential = MongoCredential.createOidcCredential(userName); + break; default: throw new UnsupportedOperationException(format("The connection string contains an invalid authentication mechanism'. " + "'%s' is not a supported authentication mechanism", diff --git a/driver-core/src/main/com/mongodb/MongoCredential.java b/driver-core/src/main/com/mongodb/MongoCredential.java index ffa2a3c4e02..418863dc21c 100644 --- a/driver-core/src/main/com/mongodb/MongoCredential.java +++ b/driver-core/src/main/com/mongodb/MongoCredential.java @@ -17,22 +17,27 @@ package com.mongodb; import com.mongodb.annotations.Beta; +import com.mongodb.annotations.Evolving; import com.mongodb.annotations.Immutable; import com.mongodb.lang.Nullable; +import java.time.Duration; import java.util.Arrays; import java.util.Collections; import java.util.HashMap; +import java.util.List; import java.util.Map; import java.util.Objects; import static com.mongodb.AuthenticationMechanism.GSSAPI; import static com.mongodb.AuthenticationMechanism.MONGODB_AWS; +import static com.mongodb.AuthenticationMechanism.MONGODB_OIDC; import static com.mongodb.AuthenticationMechanism.MONGODB_X509; import static com.mongodb.AuthenticationMechanism.PLAIN; import static com.mongodb.AuthenticationMechanism.SCRAM_SHA_1; import static com.mongodb.AuthenticationMechanism.SCRAM_SHA_256; import static com.mongodb.assertions.Assertions.notNull; +import static com.mongodb.internal.connection.OidcAuthenticator.OidcValidator.validateOidcCredentialConstruction; /** * Represents credentials to authenticate to a mongo server,as well as the source of the credentials and the authentication mechanism to @@ -179,6 +184,70 @@ public final class MongoCredential { @Beta(Beta.Reason.CLIENT) public static final String AWS_CREDENTIAL_PROVIDER_KEY = "AWS_CREDENTIAL_PROVIDER"; + /** + * The provider name. The value must be a string. + *

+ * If this is provided, + * {@link MongoCredential#REQUEST_TOKEN_CALLBACK_KEY} and + * {@link MongoCredential#REFRESH_TOKEN_CALLBACK_KEY} + * must not be provided. + * + * @see #createOidcCredential(String) + * @since 4.10 + */ + public static final String PROVIDER_NAME_KEY = "PROVIDER_NAME"; + + /** + * This callback is invoked when the OIDC-based authenticator requests + * tokens from the identity provider. The type of the value must be + * {@link OidcRequestCallback}. + *

+ * If this is provided, {@link MongoCredential#PROVIDER_NAME_KEY} + * must not be provided. + * + * @see #createOidcCredential(String) + * @since 4.10 + */ + public static final String REQUEST_TOKEN_CALLBACK_KEY = "REQUEST_TOKEN_CALLBACK"; + + /** + * Mechanism key for invoked when the OIDC-based authenticator refreshes + * tokens from the identity provider. If this callback is not provided, + * then refresh operations will not be attempted.The type of the value + * must be {@link OidcRefreshCallback}. + *

+ * If this is provided, {@link MongoCredential#PROVIDER_NAME_KEY} + * must not be provided. + * + * @see #createOidcCredential(String) + * @since 4.10 + */ + public static final String REFRESH_TOKEN_CALLBACK_KEY = "REFRESH_TOKEN_CALLBACK"; + + /** + * Mechanism key for a list of allowed hostnames or ip-addresses for MongoDB connections. Ports must be excluded. + * The hostnames may include a leading "*." wildcard, which allows for matching (potentially nested) subdomains. + * When MONGODB-OIDC authentication is attempted against a hostname that does not match any of list of allowed hosts + * the driver will raise an error. The type of the value must be {@code List}. + * + * @see MongoCredential#DEFAULT_ALLOWED_HOSTS + * @see #createOidcCredential(String) + * @since 4.10 + */ + public static final String ALLOWED_HOSTS_KEY = "ALLOWED_HOSTS"; + + /** + * The list of allowed hosts that will be used if no + * {@link MongoCredential#ALLOWED_HOSTS_KEY} value is supplied. + * The default allowed hosts are: + * {@code "*.mongodb.net", "*.mongodb-dev.net", "*.mongodbgov.net", "localhost", "127.0.0.1", "::1"} + * + * @see #createOidcCredential(String) + * @since 4.10 + */ + public static final List DEFAULT_ALLOWED_HOSTS = Collections.unmodifiableList(Arrays.asList( + "*.mongodb.net", "*.mongodb-dev.net", "*.mongodbgov.net", "localhost", "127.0.0.1", "::1")); + /** * Creates a MongoCredential instance with an unspecified mechanism. The client will negotiate the best mechanism based on the * version of the server that the client is authenticating to. @@ -327,6 +396,23 @@ public static MongoCredential createAwsCredential(@Nullable final String userNam return new MongoCredential(MONGODB_AWS, userName, "$external", password); } + /** + * Creates a MongoCredential instance for the MONGODB-OIDC mechanism. + * + * @param userName the user name, which may be null. This is the OIDC principal name. + * @return the credential + * @since 4.10 + * @see #withMechanismProperty(String, Object) + * @see #PROVIDER_NAME_KEY + * @see #REQUEST_TOKEN_CALLBACK_KEY + * @see #REFRESH_TOKEN_CALLBACK_KEY + * @see #ALLOWED_HOSTS_KEY + * @mongodb.server.release 7.0 + */ + public static MongoCredential createOidcCredential(@Nullable final String userName) { + return new MongoCredential(MONGODB_OIDC, userName, "$external", null); + } + /** * Creates a new MongoCredential as a copy of this instance, with the specified mechanism property added. * @@ -370,7 +456,11 @@ public MongoCredential withMechanism(final AuthenticationMechanism mechanism) { MongoCredential(@Nullable final AuthenticationMechanism mechanism, @Nullable final String userName, final String source, @Nullable final char[] password, final Map mechanismProperties) { - if (userName == null && !Arrays.asList(MONGODB_X509, MONGODB_AWS).contains(mechanism)) { + if (mechanism == MONGODB_OIDC) { + validateOidcCredentialConstruction(source, mechanismProperties); + } + + if (userName == null && !Arrays.asList(MONGODB_X509, MONGODB_AWS, MONGODB_OIDC).contains(mechanism)) { throw new IllegalArgumentException("username can not be null"); } @@ -543,4 +633,136 @@ public String toString() { + ", mechanismProperties=" + '}'; } + + /** + * The context for the {@link OidcRequestCallback#onRequest(OidcRequestContext) OIDC request callback}. + */ + @Evolving + public interface OidcRequestContext { + /** + * @return The OIDC Identity Provider's configuration that can be used to acquire an Access Token. + */ + IdpInfo getIdpInfo(); + + /** + * @return The timeout that this callback must complete within. + */ + Duration getTimeout(); + } + + /** + * The context for the {@link OidcRefreshCallback#onRefresh(OidcRefreshContext) OIDC refresh callback}. + */ + @Evolving + public interface OidcRefreshContext extends OidcRequestContext { + /** + * @return The OIDC Refresh token supplied by a prior callback invocation. + */ + String getRefreshToken(); + } + + /** + * This callback is invoked when the OIDC-based authenticator requests + * tokens from the identity provider. + *

+ * It does not have to be thread-safe, unless it is provided to multiple + * MongoClients. + */ + public interface OidcRequestCallback { + /** + * @param context The context. + * @return The response produced by an OIDC Identity Provider + */ + IdpResponse onRequest(OidcRequestContext context); + } + + /** + * This callback is invoked when the OIDC-based authenticator refreshes + * tokens from the identity provider. If this callback is not provided, + * then refresh operations will not be attempted. + *

+ * It does not have to be thread-safe, unless it is provided to multiple + * MongoClients. + */ + public interface OidcRefreshCallback { + /** + * @param context The context. + * @return The response produced by an OIDC Identity Provider + */ + IdpResponse onRefresh(OidcRefreshContext context); + } + + /** + * The OIDC Identity Provider's configuration that can be used to acquire an Access Token. + */ + @Evolving + public interface IdpInfo { + /** + * @return URL which describes the Authorization Server. This identifier is the + * iss of provided access tokens, and is viable for RFC8414 metadata + * discovery and RFC9207 identification. + */ + String getIssuer(); + + /** + * @return Unique client ID for this OIDC client. + */ + String getClientId(); + + /** + * @return Additional scopes to request from Identity Provider. Immutable. + */ + List getRequestScopes(); + } + + /** + * The response produced by an OIDC Identity Provider. + */ + public static final class IdpResponse { + + private final String accessToken; + + @Nullable + private final Integer accessTokenExpiresInSeconds; + + @Nullable + private final String refreshToken; + + /** + * @param accessToken The OIDC access token + * @param accessTokenExpiresInSeconds The expiration in seconds. If null, the access token is single-use. + * @param refreshToken The refresh token. If null, refresh will not be attempted. + */ + public IdpResponse(final String accessToken, @Nullable final Integer accessTokenExpiresInSeconds, + @Nullable final String refreshToken) { + notNull("accessToken", accessToken); + this.accessToken = accessToken; + this.accessTokenExpiresInSeconds = accessTokenExpiresInSeconds; + this.refreshToken = refreshToken; + } + + /** + * @return The OIDC access token. + */ + public String getAccessToken() { + return accessToken; + } + + /** + * @return The expiration time for the access token in seconds. + * If null, the access token is single-use. + */ + @Nullable + public Integer getAccessTokenExpiresInSeconds() { + return accessTokenExpiresInSeconds; + } + + /** + * @return The OIDC refresh token. If null, refresh will not be attempted. + */ + @Nullable + public String getRefreshToken() { + return refreshToken; + } + } } diff --git a/driver-core/src/main/com/mongodb/internal/Locks.java b/driver-core/src/main/com/mongodb/internal/Locks.java index 25fbe199966..c06eddcc6dd 100644 --- a/driver-core/src/main/com/mongodb/internal/Locks.java +++ b/driver-core/src/main/com/mongodb/internal/Locks.java @@ -19,6 +19,7 @@ import com.mongodb.MongoInterruptedException; import java.util.concurrent.locks.Lock; +import java.util.concurrent.locks.StampedLock; import java.util.function.Supplier; /** @@ -32,6 +33,21 @@ public static void withLock(final Lock lock, final Runnable action) { }); } + public static V withLock(final StampedLock lock, final Supplier supplier) { + long stamp; + try { + stamp = lock.writeLockInterruptibly(); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + throw new MongoInterruptedException("Interrupted waiting for lock", e); + } + try { + return supplier.get(); + } finally { + lock.unlockWrite(stamp); + } + } + public static V withLock(final Lock lock, final Supplier supplier) { return checkedWithLock(lock, supplier::get); } diff --git a/driver-core/src/main/com/mongodb/internal/connection/Authenticator.java b/driver-core/src/main/com/mongodb/internal/connection/Authenticator.java index 9ec4780d958..45e0b078452 100644 --- a/driver-core/src/main/com/mongodb/internal/connection/Authenticator.java +++ b/driver-core/src/main/com/mongodb/internal/connection/Authenticator.java @@ -21,6 +21,7 @@ import com.mongodb.ServerApi; import com.mongodb.connection.ClusterConnectionMode; import com.mongodb.connection.ConnectionDescription; +import com.mongodb.connection.ServerType; import com.mongodb.internal.async.SingleResultCallback; import com.mongodb.lang.NonNull; import com.mongodb.lang.Nullable; @@ -42,6 +43,11 @@ public abstract class Authenticator { this.serverApi = serverApi; } + public static boolean shouldAuthenticate(@Nullable final Authenticator authenticator, + final ConnectionDescription connectionDescription) { + return authenticator != null && connectionDescription.getServerType() != ServerType.REPLICA_SET_ARBITER; + } + @NonNull MongoCredentialWithCache getMongoCredentialWithCache() { return credential; @@ -93,4 +99,9 @@ T getNonNullMechanismProperty(final String key, @Nullable final T defaultVal abstract void authenticateAsync(InternalConnection connection, ConnectionDescription connectionDescription, SingleResultCallback callback); + + public void reauthenticate(final InternalConnection connection) { + authenticate(connection, connection.getDescription()); + } + } diff --git a/driver-core/src/main/com/mongodb/internal/connection/AwsAuthenticator.java b/driver-core/src/main/com/mongodb/internal/connection/AwsAuthenticator.java index ec0fc3f9c8f..35f9f8120ee 100644 --- a/driver-core/src/main/com/mongodb/internal/connection/AwsAuthenticator.java +++ b/driver-core/src/main/com/mongodb/internal/connection/AwsAuthenticator.java @@ -16,7 +16,6 @@ package com.mongodb.internal.connection; -import com.mongodb.AuthenticationMechanism; import com.mongodb.AwsCredential; import com.mongodb.MongoClientException; import com.mongodb.MongoCredential; @@ -27,14 +26,10 @@ import com.mongodb.internal.authentication.AwsCredentialHelper; import com.mongodb.lang.Nullable; import org.bson.BsonBinary; -import org.bson.BsonBinaryWriter; import org.bson.BsonDocument; import org.bson.BsonInt32; import org.bson.BsonString; import org.bson.RawBsonDocument; -import org.bson.codecs.BsonDocumentCodec; -import org.bson.codecs.EncoderContext; -import org.bson.io.BasicOutputBuffer; import javax.security.sasl.SaslClient; import javax.security.sasl.SaslException; @@ -77,27 +72,12 @@ protected SaslClient createSaslClient(final ServerAddress serverAddress) { return new AwsSaslClient(getMongoCredential()); } - private static class AwsSaslClient implements SaslClient { - private final MongoCredential credential; + private static class AwsSaslClient extends SaslClientImpl { private final byte[] clientNonce = new byte[RANDOM_LENGTH]; private int step = -1; AwsSaslClient(final MongoCredential credential) { - this.credential = credential; - } - - @Override - public String getMechanismName() { - AuthenticationMechanism authMechanism = credential.getAuthenticationMechanism(); - if (authMechanism == null) { - throw new IllegalArgumentException("Authentication mechanism cannot be null"); - } - return authMechanism.getMechanismName(); - } - - @Override - public boolean hasInitialResponse() { - return true; + super(credential); } @Override @@ -117,26 +97,6 @@ public boolean isComplete() { return step == 1; } - @Override - public byte[] unwrap(final byte[] bytes, final int i, final int i1) { - throw new UnsupportedOperationException("Not implemented yet!"); - } - - @Override - public byte[] wrap(final byte[] bytes, final int i, final int i1) { - throw new UnsupportedOperationException("Not implemented yet!"); - } - - @Override - public Object getNegotiatedProperty(final String s) { - throw new UnsupportedOperationException("Not implemented yet!"); - } - - @Override - public void dispose() { - // nothing to do - } - private byte[] computeClientFirstMessage() { new SecureRandom().nextBytes(this.clientNonce); @@ -184,6 +144,7 @@ private byte[] computeClientFinalMessage(final byte[] serverFirst) throws SaslEx private AwsCredential createAwsCredential() { AwsCredential awsCredential; + MongoCredential credential = getCredential(); if (credential.getUserName() != null) { if (credential.getPassword() == null) { throw new MongoClientException("secretAccessKey is required for AWS credential"); @@ -207,13 +168,4 @@ private AwsCredential createAwsCredential() { return awsCredential; } } - - private static byte[] toBson(final BsonDocument document) { - byte[] bytes; - BasicOutputBuffer buffer = new BasicOutputBuffer(); - new BsonDocumentCodec().encode(new BsonBinaryWriter(buffer), document, EncoderContext.builder().build()); - bytes = new byte[buffer.size()]; - System.arraycopy(buffer.getInternalBuffer(), 0, bytes, 0, buffer.getSize()); - return bytes; - } } diff --git a/driver-core/src/main/com/mongodb/internal/connection/InternalStreamConnection.java b/driver-core/src/main/com/mongodb/internal/connection/InternalStreamConnection.java index e5526c175d6..5a3e2b523ad 100644 --- a/driver-core/src/main/com/mongodb/internal/connection/InternalStreamConnection.java +++ b/driver-core/src/main/com/mongodb/internal/connection/InternalStreamConnection.java @@ -18,6 +18,7 @@ import com.mongodb.LoggerSettings; import com.mongodb.MongoClientException; +import com.mongodb.MongoCommandException; import com.mongodb.MongoCompressor; import com.mongodb.MongoException; import com.mongodb.MongoInternalException; @@ -68,6 +69,7 @@ import java.util.Map; import java.util.Set; import java.util.concurrent.atomic.AtomicBoolean; +import java.util.function.Supplier; import static com.mongodb.assertions.Assertions.assertNotNull; import static com.mongodb.assertions.Assertions.isTrue; @@ -95,6 +97,19 @@ @NotThreadSafe public class InternalStreamConnection implements InternalConnection { + private static volatile boolean recordEverything = false; + + /** + * Will attempt to record events to the command listener that are usually + * suppressed. + * + * @param recordEverything whether to attempt to record everything + */ + @VisibleForTesting(otherwise = VisibleForTesting.AccessModifier.PRIVATE) + public static void setRecordEverything(final boolean recordEverything) { + InternalStreamConnection.recordEverything = recordEverything; + } + private static final Set SECURITY_SENSITIVE_COMMANDS = new HashSet<>(asList( "authenticate", "saslStart", @@ -114,6 +129,8 @@ public class InternalStreamConnection implements InternalConnection { private static final Logger LOGGER = Loggers.getLogger("connection"); private final ClusterConnectionMode clusterConnectionMode; + @Nullable + private final Authenticator authenticator; private final boolean isMonitoringConnection; private final ServerId serverId; private final ConnectionGenerationSupplier connectionGenerationSupplier; @@ -127,6 +144,7 @@ public class InternalStreamConnection implements InternalConnection { private final AtomicBoolean isClosed = new AtomicBoolean(); private final AtomicBoolean opened = new AtomicBoolean(); + private final AtomicBoolean authenticated = new AtomicBoolean(); private final List compressorList; private final LoggerSettings loggerSettings; @@ -153,11 +171,13 @@ public InternalStreamConnection(final ClusterConnectionMode clusterConnectionMod final StreamFactory streamFactory, final List compressorList, final CommandListener commandListener, final InternalConnectionInitializer connectionInitializer, @Nullable final InetAddressResolver inetAddressResolver) { - this(clusterConnectionMode, false, serverId, connectionGenerationSupplier, streamFactory, compressorList, + this(clusterConnectionMode, null, false, serverId, connectionGenerationSupplier, streamFactory, compressorList, LoggerSettings.builder().build(), commandListener, connectionInitializer, inetAddressResolver); } - public InternalStreamConnection(final ClusterConnectionMode clusterConnectionMode, final boolean isMonitoringConnection, + public InternalStreamConnection(final ClusterConnectionMode clusterConnectionMode, + @Nullable final Authenticator authenticator, + final boolean isMonitoringConnection, final ServerId serverId, final ConnectionGenerationSupplier connectionGenerationSupplier, final StreamFactory streamFactory, final List compressorList, @@ -165,6 +185,7 @@ public InternalStreamConnection(final ClusterConnectionMode clusterConnectionMod final CommandListener commandListener, final InternalConnectionInitializer connectionInitializer, @Nullable final InetAddressResolver inetAddressResolver) { this.clusterConnectionMode = clusterConnectionMode; + this.authenticator = authenticator; this.isMonitoringConnection = isMonitoringConnection; this.serverId = notNull("serverId", serverId); this.connectionGenerationSupplier = notNull("connectionGeneration", connectionGenerationSupplier); @@ -287,6 +308,7 @@ private void initAfterHandshakeFinish(final InternalConnectionInitializationDesc description = initializationDescription.getConnectionDescription(); initialServerDescription = initializationDescription.getServerDescription(); opened.set(true); + authenticated.set(true); sendCompressor = findSendCompressor(description); } @@ -352,8 +374,35 @@ public boolean isClosed() { @Override public T sendAndReceive(final CommandMessage message, final Decoder decoder, final SessionContext sessionContext, final RequestContext requestContext, final OperationContext operationContext) { - CommandEventSender commandEventSender; + Supplier sendAndReceiveInternal = () -> sendAndReceiveInternal( + message, decoder, sessionContext, requestContext, operationContext); + try { + return sendAndReceiveInternal.get(); + } catch (MongoCommandException e) { + if (triggersReauthentication(e) && Authenticator.shouldAuthenticate(authenticator, this.description)) { + authenticated.set(false); + authenticator.reauthenticate(this); + authenticated.set(true); + return sendAndReceiveInternal.get(); + } + throw e; + } + } + + public static boolean triggersReauthentication(@Nullable final Throwable t) { + if (t instanceof MongoCommandException) { + MongoCommandException e = (MongoCommandException) t; + return e.getErrorCode() == 391; + } + return false; + } + + @Nullable + private T sendAndReceiveInternal(final CommandMessage message, final Decoder decoder, + final SessionContext sessionContext, final RequestContext requestContext, + final OperationContext operationContext) { + CommandEventSender commandEventSender; try (ByteBufferBsonOutput bsonOutput = new ByteBufferBsonOutput(this)) { message.encode(bsonOutput, sessionContext); commandEventSender = createCommandEventSender(message, bsonOutput, requestContext, operationContext); @@ -466,7 +515,7 @@ private T receiveCommandMessageResponse(final Decoder decoder, commandEventSender.sendFailedEvent(e); } throw e; - } + } } @Override @@ -860,12 +909,14 @@ public void onResult(@Nullable final ByteBuf result, @Nullable final Throwable t private CommandEventSender createCommandEventSender(final CommandMessage message, final ByteBufferBsonOutput bsonOutput, final RequestContext requestContext, final OperationContext operationContext) { - if (!isMonitoringConnection && opened() && (commandListener != null || COMMAND_PROTOCOL_LOGGER.isRequired(DEBUG, getClusterId()))) { - return new LoggingCommandEventSender(SECURITY_SENSITIVE_COMMANDS, SECURITY_SENSITIVE_HELLO_COMMANDS, description, - commandListener, requestContext, operationContext, message, bsonOutput, COMMAND_PROTOCOL_LOGGER, loggerSettings); - } else { + boolean listensOrLogs = commandListener != null || COMMAND_PROTOCOL_LOGGER.isRequired(DEBUG, getClusterId()); + if (!recordEverything && (isMonitoringConnection || !opened() || !authenticated.get() || !listensOrLogs)) { return new NoOpCommandEventSender(); } + return new LoggingCommandEventSender( + SECURITY_SENSITIVE_COMMANDS, SECURITY_SENSITIVE_HELLO_COMMANDS, description, commandListener, + requestContext, operationContext, message, bsonOutput, + COMMAND_PROTOCOL_LOGGER, loggerSettings); } private ClusterId getClusterId() { diff --git a/driver-core/src/main/com/mongodb/internal/connection/InternalStreamConnectionFactory.java b/driver-core/src/main/com/mongodb/internal/connection/InternalStreamConnectionFactory.java index 2431a3b800a..f879c642ccd 100644 --- a/driver-core/src/main/com/mongodb/internal/connection/InternalStreamConnectionFactory.java +++ b/driver-core/src/main/com/mongodb/internal/connection/InternalStreamConnectionFactory.java @@ -16,6 +16,7 @@ package com.mongodb.internal.connection; +import com.mongodb.AuthenticationMechanism; import com.mongodb.LoggerSettings; import com.mongodb.MongoCompressor; import com.mongodb.MongoDriverInformation; @@ -30,7 +31,6 @@ import java.util.List; -import static com.mongodb.assertions.Assertions.assertNotNull; import static com.mongodb.assertions.Assertions.notNull; import static com.mongodb.internal.connection.ClientMetadataHelper.createClientMetadataDocument; @@ -80,18 +80,21 @@ class InternalStreamConnectionFactory implements InternalConnectionFactory { @Override public InternalConnection create(final ServerId serverId, final ConnectionGenerationSupplier connectionGenerationSupplier) { Authenticator authenticator = credential == null ? null : createAuthenticator(credential); - return new InternalStreamConnection(clusterConnectionMode, isMonitoringConnection, serverId, connectionGenerationSupplier, + InternalStreamConnectionInitializer connectionInitializer = new InternalStreamConnectionInitializer( + clusterConnectionMode, authenticator, clientMetadataDocument, compressorList, serverApi); + return new InternalStreamConnection( + clusterConnectionMode, authenticator, + isMonitoringConnection, serverId, connectionGenerationSupplier, streamFactory, compressorList, loggerSettings, commandListener, - new InternalStreamConnectionInitializer(clusterConnectionMode, authenticator, clientMetadataDocument, compressorList, - serverApi), inetAddressResolver); + connectionInitializer, inetAddressResolver); } private Authenticator createAuthenticator(final MongoCredentialWithCache credential) { - if (credential.getAuthenticationMechanism() == null) { + AuthenticationMechanism authenticationMechanism = credential.getAuthenticationMechanism(); + if (authenticationMechanism == null) { return new DefaultAuthenticator(credential, clusterConnectionMode, serverApi); } - - switch (assertNotNull(credential.getAuthenticationMechanism())) { + switch (authenticationMechanism) { case GSSAPI: return new GSSAPIAuthenticator(credential, clusterConnectionMode, serverApi); case PLAIN: @@ -103,8 +106,10 @@ private Authenticator createAuthenticator(final MongoCredentialWithCache credent return new ScramShaAuthenticator(credential, clusterConnectionMode, serverApi); case MONGODB_AWS: return new AwsAuthenticator(credential, clusterConnectionMode, serverApi); + case MONGODB_OIDC: + return new OidcAuthenticator(credential, clusterConnectionMode, serverApi); default: - throw new IllegalArgumentException("Unsupported authentication mechanism " + credential.getAuthenticationMechanism()); + throw new IllegalArgumentException("Unsupported authentication mechanism " + authenticationMechanism); } } } diff --git a/driver-core/src/main/com/mongodb/internal/connection/InternalStreamConnectionInitializer.java b/driver-core/src/main/com/mongodb/internal/connection/InternalStreamConnectionInitializer.java index 8b102182c05..753ff235d79 100644 --- a/driver-core/src/main/com/mongodb/internal/connection/InternalStreamConnectionInitializer.java +++ b/driver-core/src/main/com/mongodb/internal/connection/InternalStreamConnectionInitializer.java @@ -25,7 +25,6 @@ import com.mongodb.connection.ConnectionDescription; import com.mongodb.connection.ConnectionId; import com.mongodb.connection.ServerDescription; -import com.mongodb.connection.ServerType; import com.mongodb.internal.async.SingleResultCallback; import com.mongodb.lang.Nullable; import org.bson.BsonArray; @@ -81,8 +80,10 @@ public InternalConnectionInitializationDescription finishHandshake(final Interna final InternalConnectionInitializationDescription description) { notNull("internalConnection", internalConnection); notNull("description", description); - - authenticate(internalConnection, description.getConnectionDescription()); + final ConnectionDescription connectionDescription = description.getConnectionDescription(); + if (Authenticator.shouldAuthenticate(authenticator, connectionDescription)) { + authenticator.authenticate(internalConnection, connectionDescription); + } return completeConnectionDescriptionInitialization(internalConnection, description); } @@ -105,11 +106,12 @@ public void startHandshakeAsync(final InternalConnection internalConnection, public void finishHandshakeAsync(final InternalConnection internalConnection, final InternalConnectionInitializationDescription description, final SingleResultCallback callback) { - if (authenticator == null || description.getConnectionDescription().getServerType() - == ServerType.REPLICA_SET_ARBITER) { + ConnectionDescription connectionDescription = description.getConnectionDescription(); + + if (!Authenticator.shouldAuthenticate(authenticator, connectionDescription)) { completeConnectionDescriptionInitializationAsync(internalConnection, description, callback); } else { - authenticator.authenticateAsync(internalConnection, description.getConnectionDescription(), + authenticator.authenticateAsync(internalConnection, connectionDescription, (result1, t1) -> { if (t1 != null) { callback.onResult(null, t1); @@ -200,12 +202,6 @@ private InternalConnectionInitializationDescription completeConnectionDescriptio description); } - private void authenticate(final InternalConnection internalConnection, final ConnectionDescription connectionDescription) { - if (authenticator != null && connectionDescription.getServerType() != ServerType.REPLICA_SET_ARBITER) { - authenticator.authenticate(internalConnection, connectionDescription); - } - } - private void setSpeculativeAuthenticateResponse(final BsonDocument helloResult) { if (authenticator instanceof SpeculativeAuthenticator) { ((SpeculativeAuthenticator) authenticator).setSpeculativeAuthenticateResponse( diff --git a/driver-core/src/main/com/mongodb/internal/connection/MongoCredentialWithCache.java b/driver-core/src/main/com/mongodb/internal/connection/MongoCredentialWithCache.java index 6b0f6e4ec3c..3d65efe48ee 100644 --- a/driver-core/src/main/com/mongodb/internal/connection/MongoCredentialWithCache.java +++ b/driver-core/src/main/com/mongodb/internal/connection/MongoCredentialWithCache.java @@ -22,8 +22,10 @@ import java.util.concurrent.locks.Lock; import java.util.concurrent.locks.ReentrantLock; +import java.util.concurrent.locks.StampedLock; import static com.mongodb.internal.Locks.withLock; +import static com.mongodb.internal.connection.OidcAuthenticator.OidcCacheEntry; /** *

This class is not part of the public API and may be removed or changed at any time

@@ -33,12 +35,12 @@ public class MongoCredentialWithCache { private final Cache cache; public MongoCredentialWithCache(final MongoCredential credential) { - this(credential, null); + this(credential, new Cache()); } - private MongoCredentialWithCache(final MongoCredential credential, @Nullable final Cache cache) { + private MongoCredentialWithCache(final MongoCredential credential, final Cache cache) { this.credential = credential; - this.cache = cache != null ? cache : new Cache(); + this.cache = cache; } public MongoCredentialWithCache withMechanism(final AuthenticationMechanism mechanism) { @@ -63,15 +65,34 @@ public void putInCache(final Object key, final Object value) { cache.set(key, value); } + OidcCacheEntry getOidcCacheEntry() { + return cache.oidcCacheEntry; + } + + void setOidcCacheEntry(final OidcCacheEntry oidcCacheEntry) { + this.cache.oidcCacheEntry = oidcCacheEntry; + } + + StampedLock getOidcLock() { + return cache.oidcLock; + } + public Lock getLock() { return cache.lock; } + /** + * Stores any state associated with the credential. + */ static class Cache { private final ReentrantLock lock = new ReentrantLock(); private Object cacheKey; private Object cacheValue; + + private final StampedLock oidcLock = new StampedLock(); + private volatile OidcCacheEntry oidcCacheEntry = new OidcCacheEntry(); + Object get(final Object key) { return withLock(lock, () -> { if (cacheKey != null && cacheKey.equals(key)) { diff --git a/driver-core/src/main/com/mongodb/internal/connection/OidcAuthenticator.java b/driver-core/src/main/com/mongodb/internal/connection/OidcAuthenticator.java new file mode 100644 index 00000000000..f3c931a433f --- /dev/null +++ b/driver-core/src/main/com/mongodb/internal/connection/OidcAuthenticator.java @@ -0,0 +1,678 @@ +/* + * Copyright 2008-present MongoDB, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.mongodb.internal.connection; + +import com.mongodb.AuthenticationMechanism; +import com.mongodb.MongoClientException; +import com.mongodb.MongoCommandException; +import com.mongodb.MongoConfigurationException; +import com.mongodb.MongoCredential; +import com.mongodb.MongoCredential.IdpInfo; +import com.mongodb.MongoCredential.IdpResponse; +import com.mongodb.MongoException; +import com.mongodb.MongoSecurityException; +import com.mongodb.ServerAddress; +import com.mongodb.ServerApi; +import com.mongodb.connection.ClusterConnectionMode; +import com.mongodb.connection.ConnectionDescription; +import com.mongodb.internal.Locks; +import com.mongodb.internal.Timeout; +import com.mongodb.internal.VisibleForTesting; +import com.mongodb.lang.Nullable; +import org.bson.BsonDocument; +import org.bson.BsonString; +import org.bson.RawBsonDocument; +import org.jetbrains.annotations.NotNull; + +import javax.security.sasl.SaslClient; +import java.io.IOException; +import java.nio.charset.StandardCharsets; +import java.nio.file.Files; +import java.nio.file.Paths; +import java.time.Duration; +import java.util.Arrays; +import java.util.Collections; +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.concurrent.TimeUnit; +import java.util.function.Function; +import java.util.stream.Collectors; + +import static com.mongodb.AuthenticationMechanism.MONGODB_OIDC; +import static com.mongodb.MongoCredential.ALLOWED_HOSTS_KEY; +import static com.mongodb.MongoCredential.DEFAULT_ALLOWED_HOSTS; +import static com.mongodb.MongoCredential.OidcRefreshCallback; +import static com.mongodb.MongoCredential.OidcRefreshContext; +import static com.mongodb.MongoCredential.OidcRequestCallback; +import static com.mongodb.MongoCredential.OidcRequestContext; +import static com.mongodb.MongoCredential.PROVIDER_NAME_KEY; +import static com.mongodb.MongoCredential.REFRESH_TOKEN_CALLBACK_KEY; +import static com.mongodb.MongoCredential.REQUEST_TOKEN_CALLBACK_KEY; +import static com.mongodb.assertions.Assertions.assertFalse; +import static com.mongodb.assertions.Assertions.assertNotNull; +import static com.mongodb.assertions.Assertions.assertTrue; +import static com.mongodb.internal.connection.OidcAuthenticator.OidcValidator.validateBeforeUse; +import static java.lang.String.format; + +/** + *

This class is not part of the public API and may be removed or changed at any time

+ */ +public final class OidcAuthenticator extends SaslAuthenticator { + + private static final List SUPPORTED_PROVIDERS = Arrays.asList("aws"); + + private static final Duration CALLBACK_TIMEOUT = Duration.ofMinutes(5); + + private static final String AWS_WEB_IDENTITY_TOKEN_FILE = "AWS_WEB_IDENTITY_TOKEN_FILE"; + + @Nullable + private ServerAddress serverAddress; + + @Nullable + private String connectionLastAccessToken; + + private FallbackState fallbackState = FallbackState.INITIAL; + + @Nullable + private BsonDocument speculativeAuthenticateResponse; + + @Nullable + private Function evaluateChallengeFunction; + + public OidcAuthenticator(final MongoCredentialWithCache credential, + final ClusterConnectionMode clusterConnectionMode, @Nullable final ServerApi serverApi) { + super(credential, clusterConnectionMode, serverApi); + validateBeforeUse(credential.getCredential()); + + if (getMongoCredential().getAuthenticationMechanism() != MONGODB_OIDC) { + throw new MongoException("Incorrect mechanism: " + getMongoCredential().getMechanism()); + } + } + + @Override + public String getMechanismName() { + return MONGODB_OIDC.getMechanismName(); + } + + @Override + protected SaslClient createSaslClient(final ServerAddress serverAddress) { + this.serverAddress = serverAddress; + MongoCredentialWithCache mongoCredentialWithCache = getMongoCredentialWithCache(); + return new OidcSaslClient(mongoCredentialWithCache); + } + + @Override + @Nullable + public BsonDocument createSpeculativeAuthenticateCommand(final InternalConnection connection) { + try { + if (isAutomaticAuthentication()) { + return wrapInSpeculative(prepareAwsTokenFromFileAsJwt()); + } + String cachedAccessToken = getValidCachedAccessToken(); + MongoCredentialWithCache mongoCredentialWithCache = getMongoCredentialWithCache(); + if (cachedAccessToken != null) { + return wrapInSpeculative(prepareTokenAsJwt(cachedAccessToken)); + } else if (mongoCredentialWithCache.getOidcCacheEntry().getIdpInfo() == null) { + String userName = mongoCredentialWithCache.getCredential().getUserName(); + return wrapInSpeculative(prepareUsername(userName)); + } else { + // otherwise, skip speculative auth + return null; + } + } catch (Exception e) { + throw wrapException(e); + } + } + + @NotNull + private BsonDocument wrapInSpeculative(final byte[] outToken) { + BsonDocument startDocument = createSaslStartCommandDocument(outToken) + .append("db", new BsonString(getMongoCredential().getSource())); + appendSaslStartOptions(startDocument); + return startDocument; + } + + @Override + @Nullable + public BsonDocument getSpeculativeAuthenticateResponse() { + BsonDocument response = speculativeAuthenticateResponse; + // response should only be read once + this.speculativeAuthenticateResponse = null; + if (response == null) { + this.connectionLastAccessToken = null; + } + return response; + } + + @Override + public void setSpeculativeAuthenticateResponse(@Nullable final BsonDocument response) { + speculativeAuthenticateResponse = response; + } + + @Nullable + private OidcRefreshCallback getRefreshCallback() { + return getMongoCredentialWithCache() + .getCredential() + .getMechanismProperty(REFRESH_TOKEN_CALLBACK_KEY, null); + } + + @Nullable + private OidcRequestCallback getRequestCallback() { + return getMongoCredentialWithCache() + .getCredential() + .getMechanismProperty(REQUEST_TOKEN_CALLBACK_KEY, null); + } + + @Override + public void reauthenticate(final InternalConnection connection) { + // method must only be called after original handshake: + assertTrue(connection.opened()); + authLock(connection, connection.getDescription()); + } + + @Override + public void authenticate(final InternalConnection connection, final ConnectionDescription connectionDescription) { + // method must only be called during original handshake: + assertFalse(connection.opened()); + // this method "wraps" the default authentication method in custom OIDC retry logic + String accessToken = getValidCachedAccessToken(); + if (accessToken != null) { + try { + authenticateUsing(connection, connectionDescription, (challenge) -> prepareTokenAsJwt(accessToken)); + } catch (MongoSecurityException e) { + if (triggersRetry(e)) { + authLock(connection, connectionDescription); + } else { + throw e; + } + } + } else { + authLock(connection, connectionDescription); + } + } + + private static boolean triggersRetry(@Nullable final Throwable t) { + if (t instanceof MongoSecurityException) { + MongoSecurityException e = (MongoSecurityException) t; + Throwable cause = e.getCause(); + if (cause instanceof MongoCommandException) { + MongoCommandException commandCause = (MongoCommandException) cause; + return commandCause.getErrorCode() == 18; + } + } + return false; + } + + private void authenticateUsing( + final InternalConnection connection, + final ConnectionDescription connectionDescription, + final Function evaluateChallengeFunction) { + this.evaluateChallengeFunction = evaluateChallengeFunction; + super.authenticate(connection, connectionDescription); + } + + private void authLock(final InternalConnection connection, final ConnectionDescription connectionDescription) { + fallbackState = FallbackState.INITIAL; + Locks.withLock(getMongoCredentialWithCache().getOidcLock(), () -> { + while (true) { + try { + authenticateUsing(connection, connectionDescription, (challenge) -> evaluate(challenge)); + break; + } catch (MongoSecurityException e) { + if (!(triggersRetry(e) && shouldRetryHandler())) { + throw e; + } + } + } + return null; + }); + } + + private byte[] evaluate(final byte[] challenge) { + if (isAutomaticAuthentication()) { + return prepareAwsTokenFromFileAsJwt(); + } + + OidcRequestCallback requestCallback = assertNotNull(getRequestCallback()); + MongoCredentialWithCache mongoCredentialWithCache = getMongoCredentialWithCache(); + OidcCacheEntry cacheEntry = mongoCredentialWithCache.getOidcCacheEntry(); + String cachedAccessToken = getValidCachedAccessToken(); + String invalidConnectionAccessToken = connectionLastAccessToken; + String cachedRefreshToken = cacheEntry.getRefreshToken(); + IdpInfo cachedIdpInfo = cacheEntry.getIdpInfo(); + + if (cachedAccessToken != null) { + boolean cachedTokenIsInvalid = cachedAccessToken.equals(invalidConnectionAccessToken); + if (cachedTokenIsInvalid) { + mongoCredentialWithCache.setOidcCacheEntry(cacheEntry.clearAccessToken()); + cachedAccessToken = null; + } + } + OidcRefreshCallback refreshCallback = getRefreshCallback(); + if (cachedAccessToken != null) { + fallbackState = FallbackState.PHASE_1_CACHED_TOKEN; + return prepareTokenAsJwt(cachedAccessToken); + } else if (refreshCallback != null && cachedRefreshToken != null) { + assertNotNull(cachedIdpInfo); + // Invoke Refresh Callback using cached Refresh Token + validateAllowedHosts(getMongoCredential()); + fallbackState = FallbackState.PHASE_2_REFRESH_CALLBACK_TOKEN; + IdpResponse result = refreshCallback.onRefresh(new OidcRefreshContextImpl( + cachedIdpInfo, cachedRefreshToken, CALLBACK_TIMEOUT)); + return populateCacheWithCallbackResultAndPrepareJwt(cachedIdpInfo, result); + } else { + // cache is empty + + /* + A check for present idp info short-circuits phase-3a. + + If a challenge is present, it can only be a response to a + "principal-request", so the challenge must be the resulting + idp info. Such a request is made during speculative auth, + though the source is unimportant, as long as we detect and + use it here. + + Checking that the fallback state is not phase-3a ensures that + this does not loop infinitely in the case of a bug. + */ + boolean idpInfoNotPresent = challenge.length == 0; + if (fallbackState != FallbackState.PHASE_3A_PRINCIPAL && idpInfoNotPresent) { + fallbackState = FallbackState.PHASE_3A_PRINCIPAL; + return prepareUsername(mongoCredentialWithCache.getCredential().getUserName()); + } else { + IdpInfo idpInfo = toIdpInfo(challenge); + validateAllowedHosts(getMongoCredential()); + IdpResponse result = requestCallback.onRequest(new OidcRequestContextImpl(idpInfo, CALLBACK_TIMEOUT)); + fallbackState = FallbackState.PHASE_3B_REQUEST_CALLBACK_TOKEN; + return populateCacheWithCallbackResultAndPrepareJwt(idpInfo, result); + } + } + } + + private boolean isAutomaticAuthentication() { + return getRequestCallback() == null; + } + + private boolean clientIsComplete() { + return fallbackState != FallbackState.PHASE_3A_PRINCIPAL; + } + + private boolean shouldRetryHandler() { + MongoCredentialWithCache mongoCredentialWithCache = getMongoCredentialWithCache(); + OidcCacheEntry cacheEntry = mongoCredentialWithCache.getOidcCacheEntry(); + if (fallbackState == FallbackState.PHASE_1_CACHED_TOKEN) { + // a cached access token failed + mongoCredentialWithCache.setOidcCacheEntry(cacheEntry + .clearAccessToken()); + } else if (fallbackState == FallbackState.PHASE_2_REFRESH_CALLBACK_TOKEN) { + // a refresh token failed + mongoCredentialWithCache.setOidcCacheEntry(cacheEntry + .clearAccessToken() + .clearRefreshToken()); + } else { + // a clean-restart failed + mongoCredentialWithCache.setOidcCacheEntry(cacheEntry + .clearAccessToken() + .clearRefreshToken()); + return false; + } + return true; + } + + @Nullable + private String getValidCachedAccessToken() { + return getMongoCredentialWithCache() + .getOidcCacheEntry() + .getValidCachedAccessToken(); + } + + static final class OidcCacheEntry { + @Nullable + private final String accessToken; + @Nullable + private final Timeout accessTokenExpiry; + @Nullable + private final String refreshToken; + @Nullable + private final IdpInfo idpInfo; + + @Override + public String toString() { + return "OidcCacheEntry{" + + "\n accessToken#hashCode='" + Objects.hashCode(accessToken) + '\'' + + ",\n accessTokenExpiry=" + accessTokenExpiry + + ",\n refreshToken='" + refreshToken + '\'' + + ",\n idpInfo=" + idpInfo + + '}'; + } + + OidcCacheEntry(final IdpInfo idpInfo, final IdpResponse idpResponse) { + Integer accessTokenExpiresInSeconds = idpResponse.getAccessTokenExpiresInSeconds(); + if (accessTokenExpiresInSeconds != null) { + this.accessToken = idpResponse.getAccessToken(); + long accessTokenExpiryReservedSeconds = TimeUnit.MINUTES.toSeconds(5); + this.accessTokenExpiry = Timeout.startNow( + Math.max(0, accessTokenExpiresInSeconds - accessTokenExpiryReservedSeconds), + TimeUnit.SECONDS); + } else { + this.accessToken = null; + this.accessTokenExpiry = null; + } + String refreshToken = idpResponse.getRefreshToken(); + if (refreshToken != null) { + this.refreshToken = refreshToken; + this.idpInfo = idpInfo; + } else { + this.refreshToken = null; + this.idpInfo = null; + } + } + + OidcCacheEntry() { + this(null, null, null, null); + } + + private OidcCacheEntry(@Nullable final String accessToken, @Nullable final Timeout accessTokenExpiry, + @Nullable final String refreshToken, @Nullable final IdpInfo idpInfo) { + this.accessToken = accessToken; + this.accessTokenExpiry = accessTokenExpiry; + this.refreshToken = refreshToken; + this.idpInfo = idpInfo; + } + + @Nullable + String getValidCachedAccessToken() { + if (accessToken == null || accessTokenExpiry == null || accessTokenExpiry.expired()) { + return null; + } + return accessToken; + } + + @Nullable + String getRefreshToken() { + return refreshToken; + } + + @Nullable + IdpInfo getIdpInfo() { + return idpInfo; + } + + OidcCacheEntry clearAccessToken() { + return new OidcCacheEntry( + null, + null, + this.refreshToken, + this.idpInfo); + } + + OidcCacheEntry clearRefreshToken() { + return new OidcCacheEntry( + this.accessToken, + this.accessTokenExpiry, + null, + null); + } + } + + private final class OidcSaslClient extends SaslClientImpl { + + private OidcSaslClient(final MongoCredentialWithCache mongoCredentialWithCache) { + super(mongoCredentialWithCache.getCredential()); + } + + @Override + public byte[] evaluateChallenge(final byte[] challenge) { + return assertNotNull(evaluateChallengeFunction).apply(challenge); + } + + @Override + public boolean isComplete() { + return clientIsComplete(); + } + + } + + private static String readAwsTokenFromFile() { + String path = System.getenv(AWS_WEB_IDENTITY_TOKEN_FILE); + if (path == null) { + throw new MongoClientException( + format("Environment variable must be specified: %s", AWS_WEB_IDENTITY_TOKEN_FILE)); + } + try { + return new String(Files.readAllBytes(Paths.get(path)), StandardCharsets.UTF_8); + } catch (IOException e) { + throw new MongoClientException(format( + "Could not read file specified by environment variable: %s at path: %s", + AWS_WEB_IDENTITY_TOKEN_FILE, path), e); + } + } + + private static byte[] prepareUsername(@Nullable final String username) { + BsonDocument document = new BsonDocument(); + if (username != null) { + document = document.append("n", new BsonString(username)); + } + return toBson(document); + } + + private byte[] populateCacheWithCallbackResultAndPrepareJwt( + final IdpInfo serverInfo, + @Nullable final IdpResponse idpResponse) { + if (idpResponse == null) { + throw new MongoConfigurationException("Result of callback must not be null"); + } + OidcCacheEntry newEntry = new OidcCacheEntry(serverInfo, idpResponse); + getMongoCredentialWithCache().setOidcCacheEntry(newEntry); + return prepareTokenAsJwt(idpResponse.getAccessToken()); + } + + private static IdpInfo toIdpInfo(final byte[] challenge) { + BsonDocument c = new RawBsonDocument(challenge); + String issuer = c.getString("issuer").getValue(); + String clientId = c.getString("clientId").getValue(); + return new IdpInfoImpl( + issuer, + clientId, + getStringArray(c, "requestScopes")); + } + + private void validateAllowedHosts(final MongoCredential credential) { + List allowedHosts = assertNotNull(credential.getMechanismProperty(ALLOWED_HOSTS_KEY, DEFAULT_ALLOWED_HOSTS)); + String host = assertNotNull(serverAddress).getHost(); + boolean permitted = allowedHosts.stream().anyMatch(allowedHost -> { + if (allowedHost.startsWith("*.")) { + String ending = allowedHost.substring(1); + return host.endsWith(ending); + } else if (allowedHost.contains("*")) { + throw new IllegalArgumentException( + "Allowed host " + allowedHost + " contains invalid wildcard"); + } else { + return host.equals(allowedHost); + } + }); + if (!permitted) { + throw new MongoSecurityException( + credential, "Host not permitted by " + ALLOWED_HOSTS_KEY + ": " + host); + } + } + + @Nullable + private static List getStringArray(final BsonDocument document, final String key) { + if (!document.isArray(key)) { + return null; + } + return document.getArray(key).stream() + // ignore non-string values from server, rather than error + .filter(v -> v.isString()) + .map(v -> v.asString().getValue()) + .collect(Collectors.toList()); + } + + private byte[] prepareTokenAsJwt(final String accessToken) { + connectionLastAccessToken = accessToken; + return toJwtDocument(accessToken); + } + + private static byte[] prepareAwsTokenFromFileAsJwt() { + String accessToken = readAwsTokenFromFile(); + return toJwtDocument(accessToken); + } + + private static byte[] toJwtDocument(final String accessToken) { + return toBson(new BsonDocument().append("jwt", new BsonString(accessToken))); + } + + /** + * Contains all validation logic for OIDC in one location + */ + public static final class OidcValidator { + private OidcValidator() { + } + + public static void validateOidcCredentialConstruction( + final String source, + final Map mechanismProperties) { + + if (!"$external".equals(source)) { + throw new IllegalArgumentException("source must be '$external'"); + } + + Object providerName = mechanismProperties.get(PROVIDER_NAME_KEY.toLowerCase()); + if (providerName != null) { + if (!(providerName instanceof String) || !SUPPORTED_PROVIDERS.contains(providerName)) { + throw new IllegalArgumentException(PROVIDER_NAME_KEY + " must be one of: " + SUPPORTED_PROVIDERS); + } + } + } + + public static void validateCreateOidcCredential(@Nullable final char[] password) { + if (password != null) { + throw new IllegalArgumentException("password must not be specified for " + + AuthenticationMechanism.MONGODB_OIDC); + } + } + + @VisibleForTesting(otherwise = VisibleForTesting.AccessModifier.PRIVATE) + public static void validateBeforeUse(final MongoCredential credential) { + String userName = credential.getUserName(); + Object providerName = credential.getMechanismProperty(PROVIDER_NAME_KEY, null); + Object requestCallback = credential.getMechanismProperty(REQUEST_TOKEN_CALLBACK_KEY, null); + Object refreshCallback = credential.getMechanismProperty(REFRESH_TOKEN_CALLBACK_KEY, null); + if (providerName == null) { + // callback + if (requestCallback == null) { + throw new IllegalArgumentException("Either " + PROVIDER_NAME_KEY + " or " + + REQUEST_TOKEN_CALLBACK_KEY + " must be specified"); + } + } else { + // automatic + if (userName != null) { + throw new IllegalArgumentException("user name must not be specified when " + PROVIDER_NAME_KEY + " is specified"); + } + if (requestCallback != null) { + throw new IllegalArgumentException(REQUEST_TOKEN_CALLBACK_KEY + " must not be specified when " + PROVIDER_NAME_KEY + " is specified"); + } + if (refreshCallback != null) { + throw new IllegalArgumentException(REFRESH_TOKEN_CALLBACK_KEY + " must not be specified when " + PROVIDER_NAME_KEY + " is specified"); + } + } + } + } + + + @VisibleForTesting(otherwise = VisibleForTesting.AccessModifier.PRIVATE) + static class OidcRequestContextImpl implements OidcRequestContext { + private final IdpInfo idpInfo; + private final Duration timeout; + + OidcRequestContextImpl(final IdpInfo idpInfo, final Duration timeout) { + this.idpInfo = assertNotNull(idpInfo); + this.timeout = assertNotNull(timeout); + } + + @Override + public IdpInfo getIdpInfo() { + return idpInfo; + } + + @Override + public Duration getTimeout() { + return timeout; + } + } + + @VisibleForTesting(otherwise = VisibleForTesting.AccessModifier.PRIVATE) + static final class OidcRefreshContextImpl extends OidcRequestContextImpl + implements OidcRefreshContext { + private final String refreshToken; + + OidcRefreshContextImpl(final IdpInfo idpInfo, final String refreshToken, + final Duration timeout) { + super(idpInfo, timeout); + this.refreshToken = assertNotNull(refreshToken); + } + + @Override + public String getRefreshToken() { + return refreshToken; + } + } + + @VisibleForTesting(otherwise = VisibleForTesting.AccessModifier.PRIVATE) + static final class IdpInfoImpl implements IdpInfo { + private final String issuer; + private final String clientId; + + private final List requestScopes; + + IdpInfoImpl(final String issuer, final String clientId, @Nullable final List requestScopes) { + this.issuer = assertNotNull(issuer); + this.clientId = assertNotNull(clientId); + this.requestScopes = requestScopes == null + ? Collections.emptyList() + : Collections.unmodifiableList(requestScopes); + } + + @Override + public String getIssuer() { + return issuer; + } + + @Override + public String getClientId() { + return clientId; + } + + @Override + public List getRequestScopes() { + return requestScopes; + } + } + + /** + * Represents what was sent in the last request to the MongoDB server. + */ + private enum FallbackState { + INITIAL, + PHASE_1_CACHED_TOKEN, + PHASE_2_REFRESH_CALLBACK_TOKEN, + PHASE_3A_PRINCIPAL, + PHASE_3B_REQUEST_CALLBACK_TOKEN + } +} diff --git a/driver-core/src/main/com/mongodb/internal/connection/SaslAuthenticator.java b/driver-core/src/main/com/mongodb/internal/connection/SaslAuthenticator.java index ebcf81c8532..e399b00bea8 100644 --- a/driver-core/src/main/com/mongodb/internal/connection/SaslAuthenticator.java +++ b/driver-core/src/main/com/mongodb/internal/connection/SaslAuthenticator.java @@ -16,6 +16,8 @@ package com.mongodb.internal.connection; +import com.mongodb.AuthenticationMechanism; +import com.mongodb.MongoCredential; import com.mongodb.MongoException; import com.mongodb.MongoInterruptedException; import com.mongodb.MongoSecurityException; @@ -30,9 +32,13 @@ import com.mongodb.lang.NonNull; import com.mongodb.lang.Nullable; import org.bson.BsonBinary; +import org.bson.BsonBinaryWriter; import org.bson.BsonDocument; import org.bson.BsonInt32; import org.bson.BsonString; +import org.bson.codecs.BsonDocumentCodec; +import org.bson.codecs.EncoderContext; +import org.bson.io.BasicOutputBuffer; import javax.security.auth.Subject; import javax.security.auth.login.LoginException; @@ -55,6 +61,7 @@ abstract class SaslAuthenticator extends Authenticator implements SpeculativeAut super(credential, clusterConnectionMode, serverApi); } + @Override public void authenticate(final InternalConnection connection, final ConnectionDescription connectionDescription) { doAsSubject(() -> { SaslClient saslClient = createSaslClient(connection.getDescription().getServerAddress()); @@ -121,9 +128,11 @@ private void throwIfSaslClientIsNull(@Nullable final SaslClient saslClient) { } private BsonDocument getNextSaslResponse(final SaslClient saslClient, final InternalConnection connection) { - BsonDocument response = getSpeculativeAuthenticateResponse(); - if (response != null) { - return response; + if (!connection.opened()) { + BsonDocument response = getSpeculativeAuthenticateResponse(); + if (response != null) { + return response; + } } try { @@ -136,9 +145,9 @@ private BsonDocument getNextSaslResponse(final SaslClient saslClient, final Inte private void getNextSaslResponseAsync(final SaslClient saslClient, final InternalConnection connection, final SingleResultCallback callback) { - BsonDocument response = getSpeculativeAuthenticateResponse(); SingleResultCallback errHandlingCallback = errorHandlingCallback(callback, LOGGER); try { + BsonDocument response = getSpeculativeAuthenticateResponse(); if (response == null) { byte[] serverResponse = (saslClient.hasInitialResponse() ? saslClient.evaluateChallenge(new byte[0]) : null); sendSaslStartAsync(serverResponse, connection, (result, t) -> { @@ -280,6 +289,15 @@ void doAsSubject(final java.security.PrivilegedAction action) { } } + static byte[] toBson(final BsonDocument document) { + byte[] bytes; + BasicOutputBuffer buffer = new BasicOutputBuffer(); + new BsonDocumentCodec().encode(new BsonBinaryWriter(buffer), document, EncoderContext.builder().build()); + bytes = new byte[buffer.size()]; + System.arraycopy(buffer.getInternalBuffer(), 0, bytes, 0, buffer.getSize()); + return bytes; + } + private final class Continuator implements SingleResultCallback { private final SaslClient saslClient; private final BsonDocument saslStartDocument; @@ -331,8 +349,52 @@ private void continueConversation(final BsonDocument result) { disposeOfSaslClient(saslClient); } } - } + protected abstract static class SaslClientImpl implements SaslClient { + private final MongoCredential credential; + + protected SaslClientImpl(final MongoCredential credential) { + this.credential = credential; + } + + @Override + public boolean hasInitialResponse() { + return true; + } + + @Override + public byte[] unwrap(final byte[] bytes, final int i, final int i1) { + throw new UnsupportedOperationException("Not implemented."); + } + + @Override + public byte[] wrap(final byte[] bytes, final int i, final int i1) { + throw new UnsupportedOperationException("Not implemented."); + } + + @Override + public Object getNegotiatedProperty(final String s) { + throw new UnsupportedOperationException("Not implemented."); + } + + @Override + public void dispose() { + // nothing to do + } + + @Override + public final String getMechanismName() { + AuthenticationMechanism authMechanism = getCredential().getAuthenticationMechanism(); + if (authMechanism == null) { + throw new IllegalArgumentException("Authentication mechanism cannot be null"); + } + return authMechanism.getMechanismName(); + } + + protected final MongoCredential getCredential() { + return credential; + } + } } diff --git a/driver-core/src/main/com/mongodb/internal/connection/ScramShaAuthenticator.java b/driver-core/src/main/com/mongodb/internal/connection/ScramShaAuthenticator.java index 5dec0d90c1e..02bc7912c93 100644 --- a/driver-core/src/main/com/mongodb/internal/connection/ScramShaAuthenticator.java +++ b/driver-core/src/main/com/mongodb/internal/connection/ScramShaAuthenticator.java @@ -92,7 +92,7 @@ protected SaslClient createSaslClient(final ServerAddress serverAddress) { if (speculativeSaslClient != null) { return speculativeSaslClient; } - return new ScramShaSaslClient(getMongoCredentialWithCache(), randomStringGenerator, authenticationHashGenerator); + return new ScramShaSaslClient(getMongoCredentialWithCache().getCredential(), randomStringGenerator, authenticationHashGenerator); } @Override @@ -122,9 +122,7 @@ public void setSpeculativeAuthenticateResponse(@Nullable final BsonDocument resp } } - class ScramShaSaslClient implements SaslClient { - - private final MongoCredentialWithCache credential; + class ScramShaSaslClient extends SaslClientImpl { private final RandomStringGenerator randomStringGenerator; private final AuthenticationHashGenerator authenticationHashGenerator; private final String hAlgorithm; @@ -136,9 +134,11 @@ class ScramShaSaslClient implements SaslClient { private byte[] serverSignature; private int step = -1; - ScramShaSaslClient(final MongoCredentialWithCache credential, final RandomStringGenerator randomStringGenerator, - final AuthenticationHashGenerator authenticationHashGenerator) { - this.credential = credential; + ScramShaSaslClient( + final MongoCredential credential, + final RandomStringGenerator randomStringGenerator, + final AuthenticationHashGenerator authenticationHashGenerator) { + super(credential); this.randomStringGenerator = randomStringGenerator; this.authenticationHashGenerator = authenticationHashGenerator; if (assertNotNull(credential.getAuthenticationMechanism()).equals(SCRAM_SHA_1)) { @@ -150,14 +150,6 @@ class ScramShaSaslClient implements SaslClient { } } - public String getMechanismName() { - return assertNotNull(credential.getAuthenticationMechanism()).getMechanismName(); - } - - public boolean hasInitialResponse() { - return true; - } - public byte[] evaluateChallenge(final byte[] challenge) throws SaslException { step++; if (step == 0) { @@ -167,7 +159,8 @@ public byte[] evaluateChallenge(final byte[] challenge) throws SaslException { } else if (step == 2) { return validateServerSignature(challenge); } else { - throw new SaslException(format("Too many steps involved in the %s negotiation.", getMechanismName())); + throw new SaslException(format("Too many steps involved in the %s negotiation.", + super.getMechanismName())); } } @@ -184,22 +177,6 @@ public boolean isComplete() { return step == 2; } - public byte[] unwrap(final byte[] incoming, final int offset, final int len) { - throw new UnsupportedOperationException("Not implemented yet!"); - } - - public byte[] wrap(final byte[] outgoing, final int offset, final int len) { - throw new UnsupportedOperationException("Not implemented yet!"); - } - - public Object getNegotiatedProperty(final String propName) { - throw new UnsupportedOperationException("Not implemented yet!"); - } - - public void dispose() { - // nothing to do - } - private byte[] computeClientFirstMessage() { clientNonce = randomStringGenerator.generate(RANDOM_LENGTH); String clientFirstMessage = "n=" + getUserName() + ",r=" + clientNonce; @@ -318,9 +295,8 @@ private HashMap parseServerResponse(final String response) { return map; } - private String getUserName() { - String userName = credential.getCredential().getUserName(); + String userName = getCredential().getUserName(); if (userName == null) { throw new IllegalArgumentException("Username can not be null"); } @@ -328,8 +304,8 @@ private String getUserName() { } private String getAuthenicationHash() { - String password = authenticationHashGenerator.generate(credential.getCredential()); - if (credential.getAuthenticationMechanism() == SCRAM_SHA_256) { + String password = authenticationHashGenerator.generate(getCredential()); + if (getCredential().getAuthenticationMechanism() == SCRAM_SHA_256) { password = SaslPrep.saslPrepStored(password); } return password; diff --git a/driver-core/src/test/functional/com/mongodb/client/TestHelper.java b/driver-core/src/test/functional/com/mongodb/client/TestHelper.java new file mode 100644 index 00000000000..237c03c7e19 --- /dev/null +++ b/driver-core/src/test/functional/com/mongodb/client/TestHelper.java @@ -0,0 +1,47 @@ +/* + * Copyright 2008-present MongoDB, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.mongodb.client; + +import com.mongodb.lang.Nullable; + +import java.lang.reflect.Field; +import java.util.Map; + +import static java.lang.System.getenv; + +public final class TestHelper { + + public static void setEnvironmentVariable(final String name, @Nullable final String value) { + try { + Map env = getenv(); + Field field = env.getClass().getDeclaredField("m"); + field.setAccessible(true); + @SuppressWarnings("unchecked") + Map result = (Map) field.get(env); + if (value == null) { + result.remove(name); + } else { + result.put(name, value); + } + } catch (IllegalAccessException | NoSuchFieldException e) { + throw new RuntimeException(e); + } + } + + private TestHelper() { + } +} diff --git a/driver-core/src/test/functional/com/mongodb/client/TestListener.java b/driver-core/src/test/functional/com/mongodb/client/TestListener.java new file mode 100644 index 00000000000..db68065432c --- /dev/null +++ b/driver-core/src/test/functional/com/mongodb/client/TestListener.java @@ -0,0 +1,43 @@ +/* + * Copyright 2008-present MongoDB, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.mongodb.client; + +import com.mongodb.annotations.ThreadSafe; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; + +/** + * A simple listener that consumes string events, which can be checked in tests. + */ +@ThreadSafe +public final class TestListener { + private final List events = Collections.synchronizedList(new ArrayList<>()); + + public void add(final String s) { + events.add(s); + } + + public List getEventStrings() { + return new ArrayList<>(events); + } + + public void clear() { + events.clear(); + } +} diff --git a/driver-core/src/test/functional/com/mongodb/internal/connection/TestCommandListener.java b/driver-core/src/test/functional/com/mongodb/internal/connection/TestCommandListener.java index fcdbeccc420..d7bf8529b48 100644 --- a/driver-core/src/test/functional/com/mongodb/internal/connection/TestCommandListener.java +++ b/driver-core/src/test/functional/com/mongodb/internal/connection/TestCommandListener.java @@ -18,11 +18,13 @@ import com.mongodb.MongoInterruptedException; import com.mongodb.MongoTimeoutException; +import com.mongodb.client.TestListener; import com.mongodb.event.CommandEvent; import com.mongodb.event.CommandFailedEvent; import com.mongodb.event.CommandListener; import com.mongodb.event.CommandStartedEvent; import com.mongodb.event.CommandSucceededEvent; +import com.mongodb.lang.Nullable; import org.bson.BsonDocument; import org.bson.BsonDocumentWriter; import org.bson.BsonDouble; @@ -55,6 +57,8 @@ public class TestCommandListener implements CommandListener { private final List eventTypes; private final List ignoredCommandMonitoringEvents; private final List events = new ArrayList<>(); + @Nullable + private final TestListener listener; private final Lock lock = new ReentrantLock(); private final Condition commandCompletedCondition = lock.newCondition(); private final boolean observeSensitiveCommands; @@ -76,25 +80,44 @@ public Codec get(final Class clazz, final CodecRegistry registry) { }); } + /** + * When a test listener is set, this command listener will send string events to the + * test listener in the form {@code " "}, where the event + * type will be lowercase and will omit the terms "command" and "event". + * For example: {@code "saslContinue succeeded"}. + * + * @see InternalStreamConnection#setRecordEverything(boolean) + * @param listener the test listener + */ + public TestCommandListener(final TestListener listener) { + this(Arrays.asList("commandStartedEvent", "commandSucceededEvent", "commandFailedEvent"), emptyList(), true, listener); + } + public TestCommandListener() { this(Arrays.asList("commandStartedEvent", "commandSucceededEvent", "commandFailedEvent"), emptyList()); } public TestCommandListener(final List eventTypes, final List ignoredCommandMonitoringEvents) { - this(eventTypes, ignoredCommandMonitoringEvents, true); + this(eventTypes, ignoredCommandMonitoringEvents, true, null); } public TestCommandListener(final List eventTypes, final List ignoredCommandMonitoringEvents, - final boolean observeSensitiveCommands) { + final boolean observeSensitiveCommands, @Nullable final TestListener listener) { this.eventTypes = eventTypes; this.ignoredCommandMonitoringEvents = ignoredCommandMonitoringEvents; this.observeSensitiveCommands = observeSensitiveCommands; + this.listener = listener; } + + public void reset() { lock.lock(); try { events.clear(); + if (listener != null) { + listener.clear(); + } } finally { lock.unlock(); } @@ -109,6 +132,18 @@ public List getEvents() { } } + private void addEvent(final CommandEvent c) { + events.add(c); + String className = c.getClass().getSimpleName() + .replace("Command", "") + .replace("Event", "") + .toLowerCase(); + // example: "saslContinue succeeded" + if (listener != null) { + listener.add(c.getCommandName() + " " + className); + } + } + public CommandStartedEvent getCommandStartedEvent(final String commandName) { for (CommandEvent event : getCommandStartedEvents()) { if (event instanceof CommandStartedEvent) { @@ -226,7 +261,7 @@ else if (!observeSensitiveCommands) { } lock.lock(); try { - events.add(new CommandStartedEvent(event.getRequestContext(), event.getOperationId(), event.getRequestId(), + addEvent(new CommandStartedEvent(event.getRequestContext(), event.getOperationId(), event.getRequestId(), event.getConnectionDescription(), event.getDatabaseName(), event.getCommandName(), event.getCommand() == null ? null : getWritableClone(event.getCommand()))); } finally { @@ -249,7 +284,7 @@ else if (!observeSensitiveCommands) { } lock.lock(); try { - events.add(new CommandSucceededEvent(event.getRequestContext(), event.getOperationId(), event.getRequestId(), + addEvent(new CommandSucceededEvent(event.getRequestContext(), event.getOperationId(), event.getRequestId(), event.getConnectionDescription(), event.getCommandName(), event.getResponse() == null ? null : event.getResponse().clone(), event.getElapsedTime(TimeUnit.NANOSECONDS))); @@ -274,7 +309,7 @@ else if (!observeSensitiveCommands) { } lock.lock(); try { - events.add(event); + addEvent(event); commandCompletedCondition.signal(); } finally { lock.unlock(); diff --git a/driver-core/src/test/resources/auth/connection-string.json b/driver-core/src/test/resources/auth/legacy/connection-string.json similarity index 74% rename from driver-core/src/test/resources/auth/connection-string.json rename to driver-core/src/test/resources/auth/legacy/connection-string.json index 2a37ae8df47..1d69685df10 100644 --- a/driver-core/src/test/resources/auth/connection-string.json +++ b/driver-core/src/test/resources/auth/legacy/connection-string.json @@ -444,6 +444,147 @@ "AWS_SESSION_TOKEN": "token!@#$%^&*()_+" } } + }, + { + "description": "should recognise the mechanism and request callback (MONGODB-OIDC)", + "uri": "mongodb://localhost/?authMechanism=MONGODB-OIDC", + "callback": ["oidcRequest"], + "valid": true, + "credential": { + "username": null, + "password": null, + "source": "$external", + "mechanism": "MONGODB-OIDC", + "mechanism_properties": { + "REQUEST_TOKEN_CALLBACK": true + } + } + }, + { + "description": "should recognise the mechanism when auth source is explicitly specified and with request callback (MONGODB-OIDC)", + "uri": "mongodb://localhost/?authMechanism=MONGODB-OIDC&authSource=$external", + "callback": ["oidcRequest"], + "valid": true, + "credential": { + "username": null, + "password": null, + "source": "$external", + "mechanism": "MONGODB-OIDC", + "mechanism_properties": { + "REQUEST_TOKEN_CALLBACK": true + } + } + }, + { + "description": "should recognise the mechanism with request and refresh callback (MONGODB-OIDC)", + "uri": "mongodb://localhost/?authMechanism=MONGODB-OIDC", + "callback": ["oidcRequest", "oidcRefresh"], + "valid": true, + "credential": { + "username": null, + "password": null, + "source": "$external", + "mechanism": "MONGODB-OIDC", + "mechanism_properties": { + "REQUEST_TOKEN_CALLBACK": true, + "REFRESH_TOKEN_CALLBACK": true + } + } + }, + { + "description": "should recognise the mechanism and username with request callback (MONGODB-OIDC)", + "uri": "mongodb://principalName@localhost/?authMechanism=MONGODB-OIDC", + "callback": ["oidcRequest"], + "valid": true, + "credential": { + "username": "principalName", + "password": null, + "source": "$external", + "mechanism": "MONGODB-OIDC", + "mechanism_properties": { + "REQUEST_TOKEN_CALLBACK": true + } + } + }, + { + "description": "should recognise the mechanism with aws device (MONGODB-OIDC)", + "uri": "mongodb://localhost/?authMechanism=MONGODB-OIDC&authMechanismProperties=PROVIDER_NAME:aws", + "valid": true, + "credential": { + "username": null, + "password": null, + "source": "$external", + "mechanism": "MONGODB-OIDC", + "mechanism_properties": { + "PROVIDER_NAME": "aws" + } + } + }, + { + "description": "should recognise the mechanism when auth source is explicitly specified and with aws device (MONGODB-OIDC)", + "uri": "mongodb://localhost/?authMechanism=MONGODB-OIDC&authSource=$external&authMechanismProperties=PROVIDER_NAME:aws", + "valid": true, + "credential": { + "username": null, + "password": null, + "source": "$external", + "mechanism": "MONGODB-OIDC", + "mechanism_properties": { + "PROVIDER_NAME": "aws" + } + } + }, + { + "description": "should throw an exception if username and password are specified (MONGODB-OIDC)", + "uri": "mongodb://user:pass@localhost/?authMechanism=MONGODB-OIDC", + "callback": ["oidcRequest"], + "valid": false, + "credential": null + }, + { + "description": "should throw an exception if username and deviceName are specified (MONGODB-OIDC)", + "uri": "mongodb://principalName@localhost/?authMechanism=MONGODB-OIDC&PROVIDER_NAME:gcp", + "valid": false, + "credential": null + }, + { + "description": "should throw an exception if specified deviceName is not supported (MONGODB-OIDC)", + "uri": "mongodb://localhost/?authMechanism=MONGODB-OIDC&authMechanismProperties=PROVIDER_NAME:unexisted", + "valid": false, + "credential": null + }, + { + "description": "should throw an exception if neither deviceName nor callbacks specified (MONGODB-OIDC)", + "uri": "mongodb://localhost/?authMechanism=MONGODB-OIDC", + "valid": false, + "credential": null + }, + { + "description": "should throw an exception when only refresh callback is specified (MONGODB-OIDC)", + "uri": "mongodb://localhost/?authMechanism=MONGODB-OIDC", + "callback": ["oidcRefresh"], + "valid": false, + "credential": null + }, + { + "description": "should throw an exception if provider name and request callback are specified", + "uri": "mongodb://localhost/?authMechanism=MONGODB-OIDC&authMechanismProperties=PROVIDER_NAME:aws", + "callback": ["oidcRequest"], + "valid": false, + "credential": null + }, + { + "description": "should throw an exception if provider name and refresh callback are specified", + "uri": "mongodb://localhost/?authMechanism=MONGODB-OIDC&authMechanismProperties=PROVIDER_NAME:aws", + "callback": ["oidcRefresh"], + "valid": false, + "credential": null + }, + { + "description": "should throw an exception when unsupported auth property is specified (MONGODB-OIDC)", + "uri": "mongodb://localhost/?authMechanism=MONGODB-OIDC&authMechanismProperties=UnsupportedProperty:unexisted", + "valid": false, + "credential": null } ] -} +} \ No newline at end of file diff --git a/driver-core/src/test/resources/unified-test-format/auth/reauthenticate_with_retry.json b/driver-core/src/test/resources/unified-test-format/auth/reauthenticate_with_retry.json new file mode 100644 index 00000000000..c99ebc6ece2 --- /dev/null +++ b/driver-core/src/test/resources/unified-test-format/auth/reauthenticate_with_retry.json @@ -0,0 +1,191 @@ +{ + "description": "reauthenticate_with_retry", + "schemaVersion": "1.12", + "runOnRequirements": [ + { + "minServerVersion": "6.3", + "auth": true + } + ], + "createEntities": [ + { + "client": { + "id": "client0", + "uriOptions": { + "retryReads": true, + "retryWrites": true + }, + "observeEvents": [ + "commandStartedEvent", + "commandSucceededEvent", + "commandFailedEvent" + ] + } + }, + { + "database": { + "id": "database0", + "client": "client0", + "databaseName": "db" + } + }, + { + "collection": { + "id": "collection0", + "database": "database0", + "collectionName": "collName" + } + } + ], + "initialData": [ + { + "collectionName": "collName", + "databaseName": "db", + "documents": [] + } + ], + "tests": [ + { + "description": "Read command should reauthenticate when receive ReauthenticationRequired error code and retryReads=true", + "operations": [ + { + "name": "failPoint", + "object": "testRunner", + "arguments": { + "client": "client0", + "failPoint": { + "configureFailPoint": "failCommand", + "mode": { + "times": 1 + }, + "data": { + "failCommands": [ + "find" + ], + "errorCode": 391 + } + } + } + }, + { + "name": "find", + "arguments": { + "filter": {} + }, + "object": "collection0", + "expectResult": [] + } + ], + "expectEvents": [ + { + "client": "client0", + "events": [ + { + "commandStartedEvent": { + "command": { + "find": "collName", + "filter": {} + } + } + }, + { + "commandFailedEvent": { + "commandName": "find" + } + }, + { + "commandStartedEvent": { + "command": { + "find": "collName", + "filter": {} + } + } + }, + { + "commandSucceededEvent": { + "commandName": "find" + } + } + ] + } + ] + }, + { + "description": "Write command should reauthenticate when receive ReauthenticationRequired error code and retryWrites=true", + "operations": [ + { + "name": "failPoint", + "object": "testRunner", + "arguments": { + "client": "client0", + "failPoint": { + "configureFailPoint": "failCommand", + "mode": { + "times": 1 + }, + "data": { + "failCommands": [ + "insert" + ], + "errorCode": 391 + } + } + } + }, + { + "name": "insertOne", + "object": "collection0", + "arguments": { + "document": { + "_id": 1, + "x": 1 + } + } + } + ], + "expectEvents": [ + { + "client": "client0", + "events": [ + { + "commandStartedEvent": { + "command": { + "insert": "collName", + "documents": [ + { + "_id": 1, + "x": 1 + } + ] + } + } + }, + { + "commandFailedEvent": { + "commandName": "insert" + } + }, + { + "commandStartedEvent": { + "command": { + "insert": "collName", + "documents": [ + { + "_id": 1, + "x": 1 + } + ] + } + } + }, + { + "commandSucceededEvent": { + "commandName": "insert" + } + } + ] + } + ] + } + ] +} \ No newline at end of file diff --git a/driver-core/src/test/resources/unified-test-format/auth/reauthenticate_without_retry.json b/driver-core/src/test/resources/unified-test-format/auth/reauthenticate_without_retry.json new file mode 100644 index 00000000000..799057bf74f --- /dev/null +++ b/driver-core/src/test/resources/unified-test-format/auth/reauthenticate_without_retry.json @@ -0,0 +1,191 @@ +{ + "description": "reauthenticate_without_retry", + "schemaVersion": "1.12", + "runOnRequirements": [ + { + "minServerVersion": "6.3", + "auth": true + } + ], + "createEntities": [ + { + "client": { + "id": "client0", + "uriOptions": { + "retryReads": false, + "retryWrites": false + }, + "observeEvents": [ + "commandStartedEvent", + "commandSucceededEvent", + "commandFailedEvent" + ] + } + }, + { + "database": { + "id": "database0", + "client": "client0", + "databaseName": "db" + } + }, + { + "collection": { + "id": "collection0", + "database": "database0", + "collectionName": "collName" + } + } + ], + "initialData": [ + { + "collectionName": "collName", + "databaseName": "db", + "documents": [] + } + ], + "tests": [ + { + "description": "Read command should reauthenticate when receive ReauthenticationRequired error code and retryReads=false", + "operations": [ + { + "name": "failPoint", + "object": "testRunner", + "arguments": { + "client": "client0", + "failPoint": { + "configureFailPoint": "failCommand", + "mode": { + "times": 1 + }, + "data": { + "failCommands": [ + "find" + ], + "errorCode": 391 + } + } + } + }, + { + "name": "find", + "arguments": { + "filter": {} + }, + "object": "collection0", + "expectResult": [] + } + ], + "expectEvents": [ + { + "client": "client0", + "events": [ + { + "commandStartedEvent": { + "command": { + "find": "collName", + "filter": {} + } + } + }, + { + "commandFailedEvent": { + "commandName": "find" + } + }, + { + "commandStartedEvent": { + "command": { + "find": "collName", + "filter": {} + } + } + }, + { + "commandSucceededEvent": { + "commandName": "find" + } + } + ] + } + ] + }, + { + "description": "Write command should reauthenticate when receive ReauthenticationRequired error code and retryWrites=false", + "operations": [ + { + "name": "failPoint", + "object": "testRunner", + "arguments": { + "client": "client0", + "failPoint": { + "configureFailPoint": "failCommand", + "mode": { + "times": 1 + }, + "data": { + "failCommands": [ + "insert" + ], + "errorCode": 391 + } + } + } + }, + { + "name": "insertOne", + "object": "collection0", + "arguments": { + "document": { + "_id": 1, + "x": 1 + } + } + } + ], + "expectEvents": [ + { + "client": "client0", + "events": [ + { + "commandStartedEvent": { + "command": { + "insert": "collName", + "documents": [ + { + "_id": 1, + "x": 1 + } + ] + } + } + }, + { + "commandFailedEvent": { + "commandName": "insert" + } + }, + { + "commandStartedEvent": { + "command": { + "insert": "collName", + "documents": [ + { + "_id": 1, + "x": 1 + } + ] + } + } + }, + { + "commandSucceededEvent": { + "commandName": "insert" + } + } + ] + } + ] + } + ] +} \ No newline at end of file diff --git a/driver-core/src/test/unit/com/mongodb/AuthConnectionStringTest.java b/driver-core/src/test/unit/com/mongodb/AuthConnectionStringTest.java index dfb81ba8de4..7f4acab857d 100644 --- a/driver-core/src/test/unit/com/mongodb/AuthConnectionStringTest.java +++ b/driver-core/src/test/unit/com/mongodb/AuthConnectionStringTest.java @@ -16,9 +16,13 @@ package com.mongodb; +import com.mongodb.internal.connection.OidcAuthenticator; +import com.mongodb.lang.Nullable; import junit.framework.TestCase; +import org.bson.BsonArray; import org.bson.BsonDocument; import org.bson.BsonNull; +import org.bson.BsonString; import org.bson.BsonValue; import org.junit.Test; import org.junit.runner.RunWith; @@ -32,7 +36,11 @@ import java.util.Collection; import java.util.List; -// See https://github.com/mongodb/specifications/tree/master/source/auth/tests +import static com.mongodb.AuthenticationMechanism.MONGODB_OIDC; +import static com.mongodb.MongoCredential.REFRESH_TOKEN_CALLBACK_KEY; +import static com.mongodb.MongoCredential.REQUEST_TOKEN_CALLBACK_KEY; + +// See https://github.com/mongodb/specifications/tree/master/source/auth/legacy/tests @RunWith(Parameterized.class) public class AuthConnectionStringTest extends TestCase { private final String input; @@ -56,7 +64,7 @@ public void shouldPassAllOutcomes() { @Parameterized.Parameters(name = "{1}") public static Collection data() throws URISyntaxException, IOException { List data = new ArrayList<>(); - for (File file : JsonPoweredTestHelper.getTestFiles("/auth")) { + for (File file : JsonPoweredTestHelper.getTestFiles("/auth/legacy")) { BsonDocument testDocument = JsonPoweredTestHelper.getTestDocument(file); for (BsonValue test : testDocument.getArray("tests")) { data.add(new Object[]{file.getName(), test.asDocument().getString("description").getValue(), @@ -69,7 +77,7 @@ public static Collection data() throws URISyntaxException, IOException private void testInvalidUris() { Throwable expectedError = null; try { - new ConnectionString(input).getCredential(); + getMongoCredential(); } catch (Throwable t) { expectedError = t; } @@ -78,7 +86,7 @@ private void testInvalidUris() { } private void testValidUris() { - MongoCredential credential = new ConnectionString(input).getCredential(); + MongoCredential credential = getMongoCredential(); if (credential != null) { assertString("credential.source", credential.getSource()); @@ -99,6 +107,36 @@ private void testValidUris() { } } + @Nullable + private MongoCredential getMongoCredential() { + ConnectionString connectionString; + connectionString = new ConnectionString(input); + MongoCredential credential = connectionString.getCredential(); + if (credential != null) { + BsonArray callbacks = (BsonArray) getExpectedValue("callback"); + if (callbacks != null) { + for (BsonValue v : callbacks) { + String string = ((BsonString) v).getValue(); + if ("oidcRequest".equals(string)) { + credential = credential.withMechanismProperty( + REQUEST_TOKEN_CALLBACK_KEY, + (MongoCredential.OidcRequestCallback) (context) -> null); + } else if ("oidcRefresh".equals(string)) { + credential = credential.withMechanismProperty( + REFRESH_TOKEN_CALLBACK_KEY, + (MongoCredential.OidcRefreshCallback) (context) -> null); + } else { + fail("Unsupported callback: " + string); + } + } + } + if (MONGODB_OIDC.getMechanismName().equals(credential.getMechanism())) { + OidcAuthenticator.OidcValidator.validateBeforeUse(credential); + } + } + return credential; + } + private void assertString(final String key, final String actual) { BsonValue expected = getExpectedValue(key); @@ -142,6 +180,14 @@ private void assertMechanismProperties(final MongoCredential credential) { } } else if ((document.get(key).isBoolean())) { boolean expectedValue = document.getBoolean(key).getValue(); + if (REQUEST_TOKEN_CALLBACK_KEY.equals(key)) { + assertTrue(actualMechanismProperty instanceof MongoCredential.OidcRequestCallback); + return; + } + if (REFRESH_TOKEN_CALLBACK_KEY.equals(key)) { + assertTrue(actualMechanismProperty instanceof MongoCredential.OidcRefreshCallback); + return; + } assertNotNull(actualMechanismProperty); assertEquals(expectedValue, actualMechanismProperty); } else { diff --git a/driver-sync/src/test/functional/com/mongodb/client/unified/Entities.java b/driver-sync/src/test/functional/com/mongodb/client/unified/Entities.java index 3bc32e26c26..69174462f77 100644 --- a/driver-sync/src/test/functional/com/mongodb/client/unified/Entities.java +++ b/driver-sync/src/test/functional/com/mongodb/client/unified/Entities.java @@ -373,7 +373,8 @@ private void initClient(final BsonDocument entity, final String id, TestCommandListener testCommandListener = new TestCommandListener( entity.getArray("observeEvents").stream() .map(type -> type.asString().getValue()).collect(Collectors.toList()), - ignoreCommandMonitoringEvents, entity.getBoolean("observeSensitiveCommands", BsonBoolean.FALSE).getValue()); + ignoreCommandMonitoringEvents, entity.getBoolean("observeSensitiveCommands", BsonBoolean.FALSE).getValue(), + null); clientSettingsBuilder.addCommandListener(testCommandListener); putEntity(id + "-command-listener", testCommandListener, clientCommandListeners); diff --git a/driver-sync/src/test/functional/com/mongodb/client/unified/UnifiedAuthTest.java b/driver-sync/src/test/functional/com/mongodb/client/unified/UnifiedAuthTest.java new file mode 100644 index 00000000000..f94977f2546 --- /dev/null +++ b/driver-sync/src/test/functional/com/mongodb/client/unified/UnifiedAuthTest.java @@ -0,0 +1,39 @@ +/* + * Copyright 2008-present MongoDB, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.mongodb.client.unified; + +import org.bson.BsonArray; +import org.bson.BsonDocument; +import org.junit.runners.Parameterized; + +import java.io.IOException; +import java.net.URISyntaxException; +import java.util.Collection; + +public class UnifiedAuthTest extends UnifiedSyncTest { + public UnifiedAuthTest(@SuppressWarnings("unused") final String fileDescription, + @SuppressWarnings("unused") final String testDescription, + final String schemaVersion, final BsonArray runOnRequirements, final BsonArray entitiesArray, + final BsonArray initialData, final BsonDocument definition) { + super(schemaVersion, runOnRequirements, entitiesArray, initialData, definition); + } + + @Parameterized.Parameters(name = "{0}: {1}") + public static Collection data() throws URISyntaxException, IOException { + return getTestData("unified-test-format/auth"); + } +} diff --git a/driver-sync/src/test/functional/com/mongodb/internal/connection/OidcAuthenticationProseTests.java b/driver-sync/src/test/functional/com/mongodb/internal/connection/OidcAuthenticationProseTests.java new file mode 100644 index 00000000000..74e95d7e253 --- /dev/null +++ b/driver-sync/src/test/functional/com/mongodb/internal/connection/OidcAuthenticationProseTests.java @@ -0,0 +1,937 @@ +/* + * Copyright 2008-present MongoDB, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.mongodb.internal.connection; + +import com.mongodb.ConnectionString; +import com.mongodb.MongoClientSettings; +import com.mongodb.MongoCommandException; +import com.mongodb.MongoConfigurationException; +import com.mongodb.MongoCredential; +import com.mongodb.MongoCredential.IdpResponse; +import com.mongodb.MongoCredential.OidcRefreshCallback; +import com.mongodb.MongoSecurityException; +import com.mongodb.client.MongoClient; +import com.mongodb.client.MongoClients; +import com.mongodb.client.TestListener; +import com.mongodb.event.CommandListener; +import com.mongodb.lang.Nullable; +import org.bson.BsonArray; +import org.bson.BsonBoolean; +import org.bson.BsonDocument; +import org.bson.BsonInt32; +import org.bson.BsonString; +import org.jetbrains.annotations.NotNull; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.function.Executable; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.CsvSource; +import org.opentest4j.AssertionFailedError; +import org.opentest4j.MultipleFailuresError; + +import java.io.IOException; +import java.nio.charset.StandardCharsets; +import java.nio.file.Files; +import java.nio.file.NoSuchFileException; +import java.nio.file.Path; +import java.nio.file.Paths; +import java.time.Duration; +import java.util.Arrays; +import java.util.List; +import java.util.Random; +import java.util.concurrent.ConcurrentLinkedQueue; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.function.Supplier; +import java.util.stream.Collectors; +import java.util.stream.Stream; + +import static com.mongodb.MongoCredential.ALLOWED_HOSTS_KEY; +import static com.mongodb.MongoCredential.IdpInfo; +import static com.mongodb.MongoCredential.OidcRefreshContext; +import static com.mongodb.MongoCredential.OidcRequestCallback; +import static com.mongodb.MongoCredential.OidcRequestContext; +import static com.mongodb.MongoCredential.PROVIDER_NAME_KEY; +import static com.mongodb.MongoCredential.REFRESH_TOKEN_CALLBACK_KEY; +import static com.mongodb.MongoCredential.REQUEST_TOKEN_CALLBACK_KEY; +import static com.mongodb.MongoCredential.createOidcCredential; +import static com.mongodb.client.TestHelper.setEnvironmentVariable; +import static java.lang.System.getenv; +import static java.util.Arrays.asList; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.junit.jupiter.api.Assertions.fail; +import static org.junit.jupiter.api.Assumptions.assumeTrue; +import static util.ThreadTestHelpers.executeAll; + + +/** + * See + * Prose Tests. + */ +public class OidcAuthenticationProseTests { + + public static boolean oidcTestsEnabled() { + return Boolean.parseBoolean(getenv().get("OIDC_TESTS_ENABLED")); + } + + private static final String AWS_WEB_IDENTITY_TOKEN_FILE = "AWS_WEB_IDENTITY_TOKEN_FILE"; + + public static final String TOKEN_DIRECTORY = "/tmp/tokens/"; // TODO-OIDC + + protected static final String OIDC_URL = "mongodb://localhost/?authMechanism=MONGODB-OIDC"; + private static final String AWS_OIDC_URL = + "mongodb://localhost/?authMechanism=MONGODB-OIDC&authMechanismProperties=PROVIDER_NAME:aws"; + private String appName; + + protected MongoClient createMongoClient(final MongoClientSettings settings) { + return MongoClients.create(settings); + } + + protected void setOidcFile(final String file) { + setEnvironmentVariable(AWS_WEB_IDENTITY_TOKEN_FILE, TOKEN_DIRECTORY + file); + } + + @BeforeEach + public void beforeEach() { + assumeTrue(oidcTestsEnabled()); + // In each test, clearing the cache is not required, since there is no global cache + setOidcFile("test_user1"); + InternalStreamConnection.setRecordEverything(true); + this.appName = this.getClass().getSimpleName() + "-" + new Random().nextInt(Integer.MAX_VALUE); + } + + @AfterEach + public void afterEach() { + InternalStreamConnection.setRecordEverything(false); + } + + @ParameterizedTest + @CsvSource(delimiter = '#', value = { + // 1.1 to 1.5: + "test1p1 # test_user1 # " + OIDC_URL, + "test1p2 # test_user1 # mongodb://test_user1@localhost/?authMechanism=MONGODB-OIDC", + "test1p3 # test_user1 # mongodb://test_user1@localhost:27018/?authMechanism=MONGODB-OIDC&directConnection=true&readPreference=secondaryPreferred", + "test1p4 # test_user2 # mongodb://test_user2@localhost:27018/?authMechanism=MONGODB-OIDC&directConnection=true&readPreference=secondaryPreferred", + "test1p5 # invalid # mongodb://localhost:27018/?authMechanism=MONGODB-OIDC&directConnection=true&readPreference=secondaryPreferred", + }) + public void test1CallbackDrivenAuth(final String name, final String file, final String url) { + boolean shouldPass = !file.equals("invalid"); + setOidcFile(file); + // #. Create a request callback that returns a valid token. + OidcRequestCallback onRequest = createCallback(); + // #. Create a client with a URL of the form ... and the OIDC request callback. + MongoClientSettings clientSettings = createSettings(url, onRequest, null); + // #. Perform a find operation that succeeds / fails + if (shouldPass) { + performFind(clientSettings); + } else { + performFind( + clientSettings, + MongoCommandException.class, + "Command failed with error 18 (AuthenticationFailed)"); + } + } + + @ParameterizedTest + @CsvSource(delimiter = '#', value = { + // 1.6, both variants: + "'' # " + OIDC_URL, + "example.com # mongodb://localhost/?authMechanism=MONGODB-OIDC&ignored=example.com", + }) + public void test1p6CallbackDrivenAuthAllowedHostsBlocked(final String allowedHosts, final String url) { + // Create a client that uses the OIDC url and a request callback, and an ALLOWED_HOSTS that contains... + List allowedHostsList = asList(allowedHosts.split(",")); + MongoClientSettings settings = createSettings(url, createCallback(), null, allowedHostsList, null); + // #. Assert that a find operation fails with a client-side error. + performFind(settings, MongoSecurityException.class, ""); + } + + @Test + public void test1p7LockAvoidsExtraCallbackCalls() { + proveThatConcurrentCallbacksThrow(); + // The test requires that two operations are attempted concurrently. + // The delay on the next find should cause the initial request to delay + // and the ensuing refresh to block, rather than entering onRefresh. + // After blocking, this ensuing refresh thread will enter onRefresh. + AtomicInteger concurrent = new AtomicInteger(); + TestCallback onRequest = createCallback().setExpired().setConcurrentTracker(concurrent); + TestCallback onRefresh = createCallback().setConcurrentTracker(concurrent); + MongoClientSettings clientSettings = createSettings(OIDC_URL, onRequest, onRefresh); + try (MongoClient mongoClient = createMongoClient(clientSettings)) { + delayNextFind(); // cause both callbacks to be called + executeAll(2, () -> performFind(mongoClient)); + assertEquals(1, onRequest.getInvocations()); + assertEquals(1, onRefresh.getInvocations()); + } + } + + public void proveThatConcurrentCallbacksThrow() { + // ensure that, via delay, test callbacks throw when invoked concurrently + AtomicInteger c = new AtomicInteger(); + TestCallback request = createCallback().setConcurrentTracker(c).setDelayMs(5); + TestCallback refresh = createCallback().setConcurrentTracker(c); + IdpInfo serverInfo = new OidcAuthenticator.IdpInfoImpl("issuer", "clientId", asList()); + executeAll(() -> { + sleep(2); + assertThrows(RuntimeException.class, () -> { + refresh.onRefresh(new OidcAuthenticator.OidcRefreshContextImpl(serverInfo, "refToken", Duration.ofSeconds(1234))); + }); + }, () -> { + request.onRequest(new OidcAuthenticator.OidcRequestContextImpl(serverInfo, Duration.ofSeconds(1234))); + }); + } + + private void sleep(final long ms) { + try { + Thread.sleep(ms); + } catch (InterruptedException e) { + throw new RuntimeException(e); + } + } + + @ParameterizedTest + @CsvSource(delimiter = '#', value = { + // 2.1 to 2.3: + "test2p1 # test_user1 # " + AWS_OIDC_URL, + "test2p2 # test_user1 # mongodb://localhost:27018/?authMechanism=MONGODB-OIDC&authMechanismProperties=PROVIDER_NAME:aws&directConnection=true&readPreference=secondaryPreferred", + "test2p3 # test_user2 # mongodb://localhost:27018/?authMechanism=MONGODB-OIDC&authMechanismProperties=PROVIDER_NAME:aws&directConnection=true&readPreference=secondaryPreferred", + }) + public void test2AwsAutomaticAuth(final String name, final String file, final String url) { + setOidcFile(file); + // #. Create a client with a url of the form ... + MongoCredential credential = createOidcCredential(null) + .withMechanismProperty(PROVIDER_NAME_KEY, "aws"); + MongoClientSettings clientSettings = MongoClientSettings.builder() + .applicationName(appName) + .credential(credential) + .applyConnectionString(new ConnectionString(url)) + .build(); + // #. Perform a find operation that succeeds. + performFind(clientSettings); + } + + @Test + public void test2p4AllowedHostsIgnored() { + MongoClientSettings settings = createSettings( + AWS_OIDC_URL, null, null, Arrays.asList(), null); + performFind(settings); + } + + @Test + public void test3p1ValidCallbacks() { + String connectionString = "mongodb://test_user1@localhost/?authMechanism=MONGODB-OIDC"; + String expectedClientId = "0oadp0hpl7q3UIehP297"; + String expectedIssuer = "https://ebgxby0dw8.execute-api.us-west-1.amazonaws.com/default/mock-identity-config-oidc"; + Duration expectedSeconds = Duration.ofMinutes(5); + + TestCallback onRequest = createCallback().setExpired(); + TestCallback onRefresh = createCallback(); + // #. Verify that the request callback was called with the appropriate + // inputs, including the timeout parameter if possible. + // #. Verify that the refresh callback was called with the appropriate + // inputs, including the timeout parameter if possible. + OidcRequestCallback onRequest2 = (context) -> { + assertEquals(expectedClientId, context.getIdpInfo().getClientId()); + assertEquals(expectedIssuer, context.getIdpInfo().getIssuer()); + assertEquals(Arrays.asList(), context.getIdpInfo().getRequestScopes()); + assertEquals(expectedSeconds, context.getTimeout()); + return onRequest.onRequest(context); + }; + OidcRefreshCallback onRefresh2 = (context) -> { + assertEquals(expectedClientId, context.getIdpInfo().getClientId()); + assertEquals(expectedIssuer, context.getIdpInfo().getIssuer()); + assertEquals(Arrays.asList(), context.getIdpInfo().getRequestScopes()); + assertEquals(expectedSeconds, context.getTimeout()); + assertEquals("refreshToken", context.getRefreshToken()); + return onRefresh.onRefresh(context); + }; + MongoClientSettings clientSettings = createSettings(connectionString, onRequest2, onRefresh2); + try (MongoClient mongoClient = createMongoClient(clientSettings)) { + delayNextFind(); // cause both callbacks to be called + executeAll(2, () -> performFind(mongoClient)); + // Ensure that both callbacks were called + assertEquals(1, onRequest.getInvocations()); + assertEquals(1, onRefresh.getInvocations()); + } + } + + @Test + public void test3p2RequestCallbackReturnsNull() { + //noinspection ConstantConditions + OidcRequestCallback onRequest = (context) -> null; + MongoClientSettings settings = this.createSettings(OIDC_URL, onRequest, null); + performFind(settings, MongoConfigurationException.class, "Result of callback must not be null"); + } + + @Test + public void test3p3RefreshCallbackReturnsNull() { + TestCallback onRequest = createCallback().setExpired().setDelayMs(100); + //noinspection ConstantConditions + OidcRefreshCallback onRefresh = (context) -> null; + MongoClientSettings clientSettings = createSettings(OIDC_URL, onRequest, onRefresh); + try (MongoClient mongoClient = createMongoClient(clientSettings)) { + delayNextFind(); // cause both callbacks to be called + try { + executeAll(2, () -> performFind(mongoClient)); + } catch (MultipleFailuresError actual) { + assertEquals(1, actual.getFailures().size()); + assertCause( + MongoConfigurationException.class, + "Result of callback must not be null", + actual.getFailures().get(0)); + } + assertEquals(1, onRequest.getInvocations()); + } + } + + @Test + public void test3p4RequestCallbackReturnsInvalidData() { + // #. Create a client with a request callback that returns data not + // conforming to the OIDCRequestTokenResult with missing field(s). + // #. ... with extra field(s). - not possible + OidcRequestCallback onRequest = (context) -> { + //noinspection ConstantConditions + return new IdpResponse(null, null, null); + }; + // we ensure that the error is propagated + MongoClientSettings clientSettings = createSettings(OIDC_URL, onRequest, null); + try (MongoClient mongoClient = createMongoClient(clientSettings)) { + try { + performFind(mongoClient); + fail(); + } catch (Exception e) { + assertCause(IllegalArgumentException.class, "accessToken can not be null", e); + } + } + } + + @Test + public void test3p5RefreshCallbackReturnsInvalidData() { + TestCallback onRequest = createCallback().setExpired(); + OidcRefreshCallback onRefresh = (context) -> { + //noinspection ConstantConditions + return new IdpResponse(null, null, null); + }; + MongoClientSettings clientSettings = createSettings(OIDC_URL, onRequest, onRefresh); + try (MongoClient mongoClient = createMongoClient(clientSettings)) { + try { + executeAll(2, () -> performFind(mongoClient)); + } catch (MultipleFailuresError actual) { + assertEquals(1, actual.getFailures().size()); + assertCause( + IllegalArgumentException.class, + "accessToken can not be null", + actual.getFailures().get(0)); + } + assertEquals(1, onRequest.getInvocations()); + } + } + + // 3.6 Refresh Callback Returns Extra Data - not possible due to use of class + + @Test + public void test4p1CachedCredentialsCacheWithRefresh() { + // #. Create a new client with a request callback that gives credentials that expire in one minute. + TestCallback onRequest = createCallback().setExpired(); + TestCallback onRefresh = createCallback(); + MongoClientSettings clientSettings = createSettings(OIDC_URL, onRequest, onRefresh); + try (MongoClient mongoClient = createMongoClient(clientSettings)) { + // #. Create a new client with the same request callback and a refresh callback. + // Instead: + // 1. Delay the first find, causing the second find to authenticate a second connection + delayNextFind(); // cause both callbacks to be called + executeAll(2, () -> performFind(mongoClient)); + // #. Ensure that a find operation adds credentials to the cache. + // #. Ensure that a find operation results in a call to the refresh callback. + assertEquals(1, onRequest.getInvocations()); + assertEquals(1, onRefresh.getInvocations()); + // the refresh invocation will fail if the cached tokens are null + // so a success implies that credentials were present in the cache + } + } + + @Test + public void test4p2CachedCredentialsCacheWithNoRefresh() { + // #. Create a new client with a request callback that gives credentials that expire in one minute. + // #. Ensure that a find operation adds credentials to the cache. + // #. Create a new client with a request callback but no refresh callback. + // #. Ensure that a find operation results in a call to the request callback. + TestCallback onRequest = createCallback().setExpired(); + MongoClientSettings clientSettings = createSettings(OIDC_URL, onRequest, null); + try (MongoClient mongoClient = createMongoClient(clientSettings)) { + delayNextFind(); // cause both callbacks to be called + executeAll(2, () -> performFind(mongoClient)); + // test is the same as 4.1, but no onRefresh, and assert that the onRequest is called twice + assertEquals(2, onRequest.getInvocations()); + } + } + + // 4.3 Cache key includes callback - skipped: + // If the driver does not support using callback references or hashes as part of the cache key, skip this test. + + @Test + public void test4p4ErrorClearsCache() { + // #. Create a new client with a valid request callback that + // gives credentials that expire within 5 minutes and + // a refresh callback that gives invalid credentials. + + TestListener listener = new TestListener(); + ConcurrentLinkedQueue tokens = tokenQueue( + "test_user1", + "test_user1_expires", + "test_user1_expires", + "test_user1_1"); + TestCallback onRequest = createCallback() + .setExpired() + .setPathSupplier(() -> tokens.remove()) + .setEventListener(listener); + TestCallback onRefresh = createCallback() + .setPathSupplier(() -> tokens.remove()) + .setEventListener(listener); + + TestCommandListener commandListener = new TestCommandListener(listener); + + MongoClientSettings clientSettings = createSettings(OIDC_URL, onRequest, onRefresh, null, commandListener); + try (MongoClient mongoClient = createMongoClient(clientSettings)) { + // #. Ensure that a find operation adds a new entry to the cache. + performFind(mongoClient); + assertEquals(Arrays.asList( + "isMaster started", + "isMaster succeeded", + "onRequest invoked", + "read access token: test_user1", + "saslContinue started", + "saslContinue succeeded", + "find started", + "find succeeded" + ), listener.getEventStrings()); + listener.clear(); + + // #. Ensure that a subsequent find operation results in a 391 error. + failCommand(391, 1, "find"); + // ensure that the operation entirely fails, after attempting both potential fallback callbacks + assertThrows(MongoSecurityException.class, () -> performFind(mongoClient)); + assertEquals(Arrays.asList( + "find started", + "find failed", + "onRefresh invoked", + "read access token: test_user1_expires", + "saslStart started", + "saslStart failed", + // falling back to principal request, onRequest callback. + "saslStart started", + "saslStart succeeded", + "onRequest invoked", + "read access token: test_user1_expires", + "saslContinue started", + "saslContinue failed" + ), listener.getEventStrings()); + listener.clear(); + + // #. Ensure that the cache value cleared. + failCommand(391, 1, "find"); + performFind(mongoClient); + assertEquals(Arrays.asList( + "find started", + "find failed", + // falling back to principal request, onRequest callback. + // this implies that the cache has been cleared during the + // preceding find operation. + "saslStart started", + "saslStart succeeded", + "onRequest invoked", + "read access token: test_user1_1", + "saslContinue started", + "saslContinue succeeded", + // auth has finished + "find started", + "find succeeded" + ), listener.getEventStrings()); + listener.clear(); + } + } + + // not a prose test. + @Test + public void testEventListenerMustNotLogReauthentication() { + InternalStreamConnection.setRecordEverything(false); + + TestListener listener = new TestListener(); + ConcurrentLinkedQueue tokens = tokenQueue( + "test_user1", + "test_user1_expires", + "test_user1_expires", + "test_user1_1"); + TestCallback onRequest = createCallback() + .setExpired() + .setPathSupplier(() -> tokens.remove()) + .setEventListener(listener); + TestCallback onRefresh = createCallback() + .setPathSupplier(() -> tokens.remove()) + .setEventListener(listener); + + TestCommandListener commandListener = new TestCommandListener(listener); + + MongoClientSettings clientSettings = createSettings(OIDC_URL, onRequest, onRefresh, null, commandListener); + try (MongoClient mongoClient = createMongoClient(clientSettings)) { + performFind(mongoClient); + assertEquals(Arrays.asList( + "onRequest invoked", + "read access token: test_user1", + "find started", + "find succeeded" + ), listener.getEventStrings()); + listener.clear(); + + failCommand(391, 1, "find"); + assertThrows(MongoSecurityException.class, () -> performFind(mongoClient)); + assertEquals(Arrays.asList( + "find started", + "find failed", + "onRefresh invoked", + "read access token: test_user1_expires", + // falling back to principal request, onRequest callback + "onRequest invoked", + "read access token: test_user1_expires" + ), listener.getEventStrings()); + } + } + + @Test + public void test4p5AwsAutomaticWorkflowDoesNotUseCache() { + // #. Create a new client that uses the AWS automatic workflow. + // #. Ensure that a find operation does not add credentials to the cache. + setOidcFile("test_user1"); + MongoCredential credential = createOidcCredential(null) + .withMechanismProperty(PROVIDER_NAME_KEY, "aws"); + ConnectionString connectionString = new ConnectionString(AWS_OIDC_URL); + MongoClientSettings clientSettings = MongoClientSettings.builder() + .applicationName(appName) + .credential(credential) + .applyConnectionString(connectionString) + .build(); + try (MongoClient mongoClient = createMongoClient(clientSettings)) { + performFind(mongoClient); + // This ensures that the next find failure results in a file (rather than cache) read + failCommand(391, 1, "find"); + setOidcFile("invalid_file"); + assertCause(NoSuchFileException.class, "invalid_file", () -> performFind(mongoClient)); + } + } + + @Test + public void test5SpeculativeAuthentication() { + // #. We can only test the successful case, by verifying that saslStart is not called. + // #. Create a client with a request callback that returns a valid token that will not expire soon. + TestListener listener = new TestListener(); + TestCallback onRequest = createCallback().setEventListener(listener); + TestCommandListener commandListener = new TestCommandListener(listener); + MongoClientSettings clientSettings = createSettings(OIDC_URL, onRequest, null, null, commandListener); + try (MongoClient mongoClient = createMongoClient(clientSettings)) { + // instead of setting failpoints for saslStart, we inspect events + delayNextFind(); + executeAll(2, () -> performFind(mongoClient)); + + List events = listener.getEventStrings(); + assertFalse(events.stream().anyMatch(e -> e.contains("saslStart"))); + // onRequest is 2-step, so we expect 2 continues + assertEquals(2, events.stream().filter(e -> e.contains("saslContinue started")).count()); + // confirm all commands are enabled + assertTrue(events.stream().anyMatch(e -> e.contains("isMaster started"))); + } + } + + // Not a prose test + @Test + public void testAutomaticAuthUsesSpeculative() { + TestListener listener = new TestListener(); + TestCommandListener commandListener = new TestCommandListener(listener); + + MongoClientSettings settings = createSettings( + AWS_OIDC_URL, null, null, Arrays.asList(), commandListener); + try (MongoClient mongoClient = createMongoClient(settings)) { + // we use a listener instead of a failpoint + performFind(mongoClient); + assertEquals(Arrays.asList( + "isMaster started", + "isMaster succeeded", + "find started", + "find succeeded" + ), listener.getEventStrings()); + } + } + + @Test + public void test6p1ReauthenticationSucceeds() { + // #. Create request and refresh callbacks that return valid credentials that will not expire soon. + TestListener listener = new TestListener(); + TestCallback onRequest = createCallback().setEventListener(listener); + TestCallback onRefresh = createCallback().setEventListener(listener); + + // #. Create a client with the callbacks and an event listener capable of listening for SASL commands. + TestCommandListener commandListener = new TestCommandListener(listener); + + MongoClientSettings clientSettings = createSettings(OIDC_URL, onRequest, onRefresh, null, commandListener); + try (MongoClient mongoClient = createMongoClient(clientSettings)) { + + // #. Perform a find operation that succeeds. + performFind(mongoClient); + + // #. Assert that the refresh callback has not been called. + assertEquals(0, onRefresh.getInvocations()); + + assertEquals(Arrays.asList( + "isMaster started", + "isMaster succeeded", + "onRequest invoked", + "read access token: test_user1", + "saslContinue started", + "saslContinue succeeded", + "find started", + "find succeeded" + ), listener.getEventStrings()); + + // #. Clear the listener state if possible. + commandListener.reset(); + listener.clear(); + + // #. Force a reauthenication using a failCommand + failCommand(391, 1, "find"); + + // #. Perform another find operation that succeeds. + performFind(mongoClient); + + // #. Assert that the ordering of command started events is: find, find. + // #. Assert that the ordering of command succeeded events is: find. + // #. Assert that a find operation failed once during the command execution. + assertEquals(Arrays.asList( + "find started", + "find failed", + "onRefresh invoked", + "read access token: test_user1", + "saslStart started", + "saslStart succeeded", + "find started", + "find succeeded" + ), listener.getEventStrings()); + + // #. Assert that the refresh callback has been called once, if possible. + assertEquals(1, onRefresh.getInvocations()); + } + } + + @NotNull + private ConcurrentLinkedQueue tokenQueue(final String... queue) { + return Stream + .of(queue) + .map(v -> TOKEN_DIRECTORY + v) + .collect(Collectors.toCollection(ConcurrentLinkedQueue::new)); + } + + @Test + public void test6p2ReauthenticationRetriesAndSucceedsWithCache() { + // #. Create request and refresh callbacks that return valid credentials that will not expire soon. + TestCallback onRequest = createCallback(); + TestCallback onRefresh = createCallback(); + MongoClientSettings clientSettings = createSettings(OIDC_URL, onRequest, onRefresh); + try (MongoClient mongoClient = createMongoClient(clientSettings)) { + // #. Perform a find operation that succeeds. + performFind(mongoClient); + // #. Force a reauthenication using a failCommand + failCommand(391, 1, "find"); + // #. Perform a find operation that succeeds. + performFind(mongoClient); + } + } + + // 6.3 Retries and Fails with no Cache + // Appears to be untestable, since it requires 391 failure on jwt (may be fixed in future spec) + + @Test + public void test6p4SeparateConnectionsAvoidExtraCallbackCalls() { + ConcurrentLinkedQueue tokens = tokenQueue( + "test_user1", + "test_user1_1"); + TestCallback onRequest = createCallback().setPathSupplier(() -> tokens.remove()); + TestCallback onRefresh = createCallback().setPathSupplier(() -> tokens.remove()); + MongoClientSettings clientSettings = createSettings(OIDC_URL, onRequest, onRefresh); + try (MongoClient mongoClient = createMongoClient(clientSettings)) { + // #. Peform a find operation on each ... that succeeds. + delayNextFind(); + executeAll(2, () -> performFind(mongoClient)); + // #. Ensure that the request callback has been called once and the refresh callback has not been called. + assertEquals(1, onRequest.getInvocations()); + assertEquals(0, onRefresh.getInvocations()); + + failCommand(391, 2, "find"); + executeAll(2, () -> performFind(mongoClient)); + + // #. Ensure that the request callback has been called once and the refresh callback has been called once. + assertEquals(1, onRequest.getInvocations()); + assertEquals(1, onRefresh.getInvocations()); + } + } + + public MongoClientSettings createSettings( + final String connectionString, + @Nullable final OidcRequestCallback onRequest, + @Nullable final OidcRefreshCallback onRefresh) { + return createSettings(connectionString, onRequest, onRefresh, null, null); + } + + private MongoClientSettings createSettings( + final String connectionString, + @Nullable final OidcRequestCallback onRequest, + @Nullable final OidcRefreshCallback onRefresh, + @Nullable final List allowedHosts, + @Nullable final CommandListener commandListener) { + ConnectionString cs = new ConnectionString(connectionString); + MongoCredential credential = cs.getCredential() + .withMechanismProperty(REQUEST_TOKEN_CALLBACK_KEY, onRequest) + .withMechanismProperty(REFRESH_TOKEN_CALLBACK_KEY, onRefresh) + .withMechanismProperty(ALLOWED_HOSTS_KEY, allowedHosts); + MongoClientSettings.Builder builder = MongoClientSettings.builder() + .applicationName(appName) + .applyConnectionString(cs) + .credential(credential); + if (commandListener != null) { + builder.addCommandListener(commandListener); + } + return builder.build(); + } + + private void performFind(final MongoClientSettings settings) { + try (MongoClient mongoClient = createMongoClient(settings)) { + performFind(mongoClient); + } + } + + private void performFind( + final MongoClientSettings settings, + final Class expectedExceptionOrCause, + final String expectedMessage) { + try (MongoClient mongoClient = createMongoClient(settings)) { + assertCause(expectedExceptionOrCause, expectedMessage, () -> performFind(mongoClient)); + } + } + + private void performFind(final MongoClient mongoClient) { + mongoClient + .getDatabase("test") + .getCollection("test") + .find() + .first(); + } + + private static void assertCause( + final Class expectedCause, final String expectedMessageFragment, final Executable e) { + Throwable actualException = assertThrows(Throwable.class, e); + assertCause(expectedCause, expectedMessageFragment, actualException); + } + + private static void assertCause( + final Class expectedCause, final String expectedMessageFragment, final Throwable actualException) { + Throwable cause = actualException; + while (cause.getCause() != null) { + cause = cause.getCause(); + } + if (!expectedCause.isInstance(cause)) { + throw new AssertionFailedError("Unexpected cause", actualException); + } + if (!cause.getMessage().contains(expectedMessageFragment)) { + throw new AssertionFailedError("Unexpected message", actualException); + } + } + + protected void delayNextFind() { + try (MongoClient client = createMongoClient(createSettings(AWS_OIDC_URL, null, null))) { + BsonDocument failPointDocument = new BsonDocument("configureFailPoint", new BsonString("failCommand")) + .append("mode", new BsonDocument("times", new BsonInt32(1))) + .append("data", new BsonDocument() + .append("appName", new BsonString(appName)) + .append("failCommands", new BsonArray(asList(new BsonString("find")))) + .append("blockConnection", new BsonBoolean(true)) + .append("blockTimeMS", new BsonInt32(100))); + client.getDatabase("admin").runCommand(failPointDocument); + } + } + + protected void failCommand(final int code, final int times, final String... commands) { + try (MongoClient mongoClient = createMongoClient(createSettings( + AWS_OIDC_URL, null, null))) { + List list = Arrays.stream(commands).map(c -> new BsonString(c)).collect(Collectors.toList()); + BsonDocument failPointDocument = new BsonDocument("configureFailPoint", new BsonString("failCommand")) + .append("mode", new BsonDocument("times", new BsonInt32(times))) + .append("data", new BsonDocument() + .append("appName", new BsonString(appName)) + .append("failCommands", new BsonArray(list)) + .append("errorCode", new BsonInt32(code))); + mongoClient.getDatabase("admin").runCommand(failPointDocument); + } + } + + public static class TestCallback implements OidcRequestCallback, OidcRefreshCallback { + private final AtomicInteger invocations = new AtomicInteger(); + @Nullable + private final Integer expiresInSeconds; + @Nullable + private final Integer delayInMilliseconds; + @Nullable + private final AtomicInteger concurrentTracker; + @Nullable + private final TestListener testListener; + @Nullable + private final Supplier pathSupplier; + + public TestCallback() { + this(60 * 60, null, new AtomicInteger(), null, null); + } + + public TestCallback( + @Nullable final Integer expiresInSeconds, + @Nullable final Integer delayInMilliseconds, + @Nullable final AtomicInteger concurrentTracker, + @Nullable final TestListener testListener, + @Nullable final Supplier pathSupplier) { + this.expiresInSeconds = expiresInSeconds; + this.delayInMilliseconds = delayInMilliseconds; + this.concurrentTracker = concurrentTracker; + this.testListener = testListener; + this.pathSupplier = pathSupplier; + } + + public int getInvocations() { + return invocations.get(); + } + + @Override + public IdpResponse onRequest(final OidcRequestContext context) { + if (testListener != null) { + testListener.add("onRequest invoked"); + } + return callback(); + } + + @Override + public IdpResponse onRefresh(final OidcRefreshContext context) { + if (context.getRefreshToken() == null) { + throw new IllegalArgumentException("refreshToken was null"); + } + if (testListener != null) { + testListener.add("onRefresh invoked"); + } + return callback(); + } + + @NotNull + private IdpResponse callback() { + if (concurrentTracker != null) { + if (concurrentTracker.get() > 0) { + throw new RuntimeException("Callbacks should not be invoked by multiple threads."); + } + concurrentTracker.incrementAndGet(); + } + try { + invocations.incrementAndGet(); + Path path = Paths.get(pathSupplier == null + ? getenv(AWS_WEB_IDENTITY_TOKEN_FILE) + : pathSupplier.get()); + String accessToken; + try { + simulateDelay(); + accessToken = new String(Files.readAllBytes(path), StandardCharsets.UTF_8); + } catch (IOException | InterruptedException e) { + throw new RuntimeException(e); + } + String refreshToken = "refreshToken"; + if (testListener != null) { + testListener.add("read access token: " + path.getFileName()); + } + return new IdpResponse( + accessToken, + expiresInSeconds, + refreshToken); + } finally { + if (concurrentTracker != null) { + concurrentTracker.decrementAndGet(); + } + } + } + + private void simulateDelay() throws InterruptedException { + if (delayInMilliseconds != null) { + Thread.sleep(delayInMilliseconds); + } + } + + public TestCallback setExpiresInSeconds(final Integer expiresInSeconds) { + return new TestCallback( + expiresInSeconds, + this.delayInMilliseconds, + this.concurrentTracker, + this.testListener, + this.pathSupplier); + } + + public TestCallback setDelayMs(final int milliseconds) { + return new TestCallback( + this.expiresInSeconds, + milliseconds, + this.concurrentTracker, + this.testListener, + this.pathSupplier); + } + + public TestCallback setConcurrentTracker(final AtomicInteger c) { + return new TestCallback( + this.expiresInSeconds, + this.delayInMilliseconds, + c, + this.testListener, + this.pathSupplier); + } + + public TestCallback setEventListener(final TestListener testListener) { + return new TestCallback( + this.expiresInSeconds, + this.delayInMilliseconds, + this.concurrentTracker, + testListener, + this.pathSupplier); + } + + public TestCallback setPathSupplier(final Supplier pathSupplier) { + return new TestCallback( + this.expiresInSeconds, + this.delayInMilliseconds, + this.concurrentTracker, + this.testListener, + pathSupplier); + } + + public TestCallback setExpired() { + return this.setExpiresInSeconds(60); + } + } + + public TestCallback createCallback() { + return new TestCallback(); + } +}