From f6ce5d55e8029bdb48618c1f128a41b488488c60 Mon Sep 17 00:00:00 2001 From: Maxim Katcharov Date: Wed, 5 Apr 2023 11:20:18 -0600 Subject: [PATCH 01/19] Add unified test data --- .../auth/{ => legacy}/connection-string.json | 129 +++++++++++- .../auth/reauthenticate_with_retry.json | 191 ++++++++++++++++++ .../auth/reauthenticate_without_retry.json | 191 ++++++++++++++++++ .../client/unified/UnifiedAuthTest.java | 39 ++++ 4 files changed, 549 insertions(+), 1 deletion(-) rename driver-core/src/test/resources/auth/{ => legacy}/connection-string.json (76%) create mode 100644 driver-core/src/test/resources/unified-test-format/auth/reauthenticate_with_retry.json create mode 100644 driver-core/src/test/resources/unified-test-format/auth/reauthenticate_without_retry.json create mode 100644 driver-sync/src/test/functional/com/mongodb/client/unified/UnifiedAuthTest.java diff --git a/driver-core/src/test/resources/auth/connection-string.json b/driver-core/src/test/resources/auth/legacy/connection-string.json similarity index 76% rename from driver-core/src/test/resources/auth/connection-string.json rename to driver-core/src/test/resources/auth/legacy/connection-string.json index 2a37ae8df47..3aa0ae2dc4d 100644 --- a/driver-core/src/test/resources/auth/connection-string.json +++ b/driver-core/src/test/resources/auth/legacy/connection-string.json @@ -444,6 +444,133 @@ "AWS_SESSION_TOKEN": "token!@#$%^&*()_+" } } + }, + { + "description": "should recognise the mechanism and request callback (MONGODB-OIDC)", + "uri": "mongodb://localhost/?authMechanism=MONGODB-OIDC", + "callback": ["oidcRequest"], + "valid": true, + "credential": { + "username": null, + "password": null, + "source": "$external", + "mechanism": "MONGODB-OIDC", + "mechanism_properties": { + "REQUEST_TOKEN_CALLBACK": true + } + } + }, + { + "description": "should recognise the mechanism when auth source is explicitly specified and with request callback (MONGODB-OIDC)", + "uri": "mongodb://localhost/?authMechanism=MONGODB-OIDC&authSource=$external", + "callback": ["oidcRequest"], + "valid": true, + "credential": { + "username": null, + "password": null, + "source": "$external", + "mechanism": "MONGODB-OIDC", + "mechanism_properties": { + "REQUEST_TOKEN_CALLBACK": true + } + } + }, + { + "description": "should recognise the mechanism with request and refresh callback (MONGODB-OIDC)", + "uri": "mongodb://localhost/?authMechanism=MONGODB-OIDC", + "callback": ["oidcRequest", "oidcRefresh"], + "valid": true, + "credential": { + "username": null, + "password": null, + "source": "$external", + "mechanism": "MONGODB-OIDC", + "mechanism_properties": { + "REQUEST_TOKEN_CALLBACK": true, + "REFRESH_TOKEN_CALLBACK": true + } + } + }, + { + "description": "should recognise the mechanism and username with request callback (MONGODB-OIDC)", + "uri": "mongodb://principalName@localhost/?authMechanism=MONGODB-OIDC", + "callback": ["oidcRequest"], + "valid": true, + "credential": { + "username": "principalName", + "password": null, + "source": "$external", + "mechanism": "MONGODB-OIDC", + "mechanism_properties": { + "REQUEST_TOKEN_CALLBACK": true + } + } + }, + { + "description": "should recognise the mechanism with aws device (MONGODB-OIDC)", + "uri": "mongodb://localhost/?authMechanism=MONGODB-OIDC&authMechanismProperties=PROVIDER_NAME:aws", + "valid": true, + "credential": { + "username": null, + "password": null, + "source": "$external", + "mechanism": "MONGODB-OIDC", + "mechanism_properties": { + "PROVIDER_NAME": "aws" + } + } + }, + { + "description": "should recognise the mechanism when auth source is explicitly specified and with aws device (MONGODB-OIDC)", + "uri": "mongodb://localhost/?authMechanism=MONGODB-OIDC&authSource=$external&authMechanismProperties=PROVIDER_NAME:aws", + "valid": true, + "credential": { + "username": null, + "password": null, + "source": "$external", + "mechanism": "MONGODB-OIDC", + "mechanism_properties": { + "PROVIDER_NAME": "aws" + } + } + }, + { + "description": "should throw an exception if username and password are specified (MONGODB-OIDC)", + "uri": "mongodb://user:pass@localhost/?authMechanism=MONGODB-OIDC", + "callback": ["oidcRequest"], + "valid": false, + "credential": null + }, + { + "description": "should throw an exception if username and deviceName are specified (MONGODB-OIDC)", + "uri": "mongodb://principalName@localhost/?authMechanism=MONGODB-OIDC&PROVIDER_NAME:gcp", + "valid": false, + "credential": null + }, + { + "description": "should throw an exception if specified deviceName is not supported (MONGODB-OIDC)", + "uri": "mongodb://localhost/?authMechanism=MONGODB-OIDC&authMechanismProperties=PROVIDER_NAME:unexisted", + "valid": false, + "credential": null + }, + { + "description": "should throw an exception if neither deviceName nor callbacks specified (MONGODB-OIDC)", + "uri": "mongodb://localhost/?authMechanism=MONGODB-OIDC", + "valid": false, + "credential": null + }, + { + "description": "should throw an exception when only refresh callback is specified (MONGODB-OIDC)", + "uri": "mongodb://localhost/?authMechanism=MONGODB-OIDC", + "callback": ["oidcRefresh"], + "valid": false, + "credential": null + }, + { + "description": "should throw an exception when unsupported auth property is specified (MONGODB-OIDC)", + "uri": "mongodb://localhost/?authMechanism=MONGODB-OIDC&authMechanismProperties=UnsupportedProperty:unexisted", + "valid": false, + "credential": null } ] -} +} \ No newline at end of file diff --git a/driver-core/src/test/resources/unified-test-format/auth/reauthenticate_with_retry.json b/driver-core/src/test/resources/unified-test-format/auth/reauthenticate_with_retry.json new file mode 100644 index 00000000000..c99ebc6ece2 --- /dev/null +++ b/driver-core/src/test/resources/unified-test-format/auth/reauthenticate_with_retry.json @@ -0,0 +1,191 @@ +{ + "description": "reauthenticate_with_retry", + "schemaVersion": "1.12", + "runOnRequirements": [ + { + "minServerVersion": "6.3", + "auth": true + } + ], + "createEntities": [ + { + "client": { + "id": "client0", + "uriOptions": { + "retryReads": true, + "retryWrites": true + }, + "observeEvents": [ + "commandStartedEvent", + "commandSucceededEvent", + "commandFailedEvent" + ] + } + }, + { + "database": { + "id": "database0", + "client": "client0", + "databaseName": "db" + } + }, + { + "collection": { + "id": "collection0", + "database": "database0", + "collectionName": "collName" + } + } + ], + "initialData": [ + { + "collectionName": "collName", + "databaseName": "db", + "documents": [] + } + ], + "tests": [ + { + "description": "Read command should reauthenticate when receive ReauthenticationRequired error code and retryReads=true", + "operations": [ + { + "name": "failPoint", + "object": "testRunner", + "arguments": { + "client": "client0", + "failPoint": { + "configureFailPoint": "failCommand", + "mode": { + "times": 1 + }, + "data": { + "failCommands": [ + "find" + ], + "errorCode": 391 + } + } + } + }, + { + "name": "find", + "arguments": { + "filter": {} + }, + "object": "collection0", + "expectResult": [] + } + ], + "expectEvents": [ + { + "client": "client0", + "events": [ + { + "commandStartedEvent": { + "command": { + "find": "collName", + "filter": {} + } + } + }, + { + "commandFailedEvent": { + "commandName": "find" + } + }, + { + "commandStartedEvent": { + "command": { + "find": "collName", + "filter": {} + } + } + }, + { + "commandSucceededEvent": { + "commandName": "find" + } + } + ] + } + ] + }, + { + "description": "Write command should reauthenticate when receive ReauthenticationRequired error code and retryWrites=true", + "operations": [ + { + "name": "failPoint", + "object": "testRunner", + "arguments": { + "client": "client0", + "failPoint": { + "configureFailPoint": "failCommand", + "mode": { + "times": 1 + }, + "data": { + "failCommands": [ + "insert" + ], + "errorCode": 391 + } + } + } + }, + { + "name": "insertOne", + "object": "collection0", + "arguments": { + "document": { + "_id": 1, + "x": 1 + } + } + } + ], + "expectEvents": [ + { + "client": "client0", + "events": [ + { + "commandStartedEvent": { + "command": { + "insert": "collName", + "documents": [ + { + "_id": 1, + "x": 1 + } + ] + } + } + }, + { + "commandFailedEvent": { + "commandName": "insert" + } + }, + { + "commandStartedEvent": { + "command": { + "insert": "collName", + "documents": [ + { + "_id": 1, + "x": 1 + } + ] + } + } + }, + { + "commandSucceededEvent": { + "commandName": "insert" + } + } + ] + } + ] + } + ] +} \ No newline at end of file diff --git a/driver-core/src/test/resources/unified-test-format/auth/reauthenticate_without_retry.json b/driver-core/src/test/resources/unified-test-format/auth/reauthenticate_without_retry.json new file mode 100644 index 00000000000..799057bf74f --- /dev/null +++ b/driver-core/src/test/resources/unified-test-format/auth/reauthenticate_without_retry.json @@ -0,0 +1,191 @@ +{ + "description": "reauthenticate_without_retry", + "schemaVersion": "1.12", + "runOnRequirements": [ + { + "minServerVersion": "6.3", + "auth": true + } + ], + "createEntities": [ + { + "client": { + "id": "client0", + "uriOptions": { + "retryReads": false, + "retryWrites": false + }, + "observeEvents": [ + "commandStartedEvent", + "commandSucceededEvent", + "commandFailedEvent" + ] + } + }, + { + "database": { + "id": "database0", + "client": "client0", + "databaseName": "db" + } + }, + { + "collection": { + "id": "collection0", + "database": "database0", + "collectionName": "collName" + } + } + ], + "initialData": [ + { + "collectionName": "collName", + "databaseName": "db", + "documents": [] + } + ], + "tests": [ + { + "description": "Read command should reauthenticate when receive ReauthenticationRequired error code and retryReads=false", + "operations": [ + { + "name": "failPoint", + "object": "testRunner", + "arguments": { + "client": "client0", + "failPoint": { + "configureFailPoint": "failCommand", + "mode": { + "times": 1 + }, + "data": { + "failCommands": [ + "find" + ], + "errorCode": 391 + } + } + } + }, + { + "name": "find", + "arguments": { + "filter": {} + }, + "object": "collection0", + "expectResult": [] + } + ], + "expectEvents": [ + { + "client": "client0", + "events": [ + { + "commandStartedEvent": { + "command": { + "find": "collName", + "filter": {} + } + } + }, + { + "commandFailedEvent": { + "commandName": "find" + } + }, + { + "commandStartedEvent": { + "command": { + "find": "collName", + "filter": {} + } + } + }, + { + "commandSucceededEvent": { + "commandName": "find" + } + } + ] + } + ] + }, + { + "description": "Write command should reauthenticate when receive ReauthenticationRequired error code and retryWrites=false", + "operations": [ + { + "name": "failPoint", + "object": "testRunner", + "arguments": { + "client": "client0", + "failPoint": { + "configureFailPoint": "failCommand", + "mode": { + "times": 1 + }, + "data": { + "failCommands": [ + "insert" + ], + "errorCode": 391 + } + } + } + }, + { + "name": "insertOne", + "object": "collection0", + "arguments": { + "document": { + "_id": 1, + "x": 1 + } + } + } + ], + "expectEvents": [ + { + "client": "client0", + "events": [ + { + "commandStartedEvent": { + "command": { + "insert": "collName", + "documents": [ + { + "_id": 1, + "x": 1 + } + ] + } + } + }, + { + "commandFailedEvent": { + "commandName": "insert" + } + }, + { + "commandStartedEvent": { + "command": { + "insert": "collName", + "documents": [ + { + "_id": 1, + "x": 1 + } + ] + } + } + }, + { + "commandSucceededEvent": { + "commandName": "insert" + } + } + ] + } + ] + } + ] +} \ No newline at end of file diff --git a/driver-sync/src/test/functional/com/mongodb/client/unified/UnifiedAuthTest.java b/driver-sync/src/test/functional/com/mongodb/client/unified/UnifiedAuthTest.java new file mode 100644 index 00000000000..f94977f2546 --- /dev/null +++ b/driver-sync/src/test/functional/com/mongodb/client/unified/UnifiedAuthTest.java @@ -0,0 +1,39 @@ +/* + * Copyright 2008-present MongoDB, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.mongodb.client.unified; + +import org.bson.BsonArray; +import org.bson.BsonDocument; +import org.junit.runners.Parameterized; + +import java.io.IOException; +import java.net.URISyntaxException; +import java.util.Collection; + +public class UnifiedAuthTest extends UnifiedSyncTest { + public UnifiedAuthTest(@SuppressWarnings("unused") final String fileDescription, + @SuppressWarnings("unused") final String testDescription, + final String schemaVersion, final BsonArray runOnRequirements, final BsonArray entitiesArray, + final BsonArray initialData, final BsonDocument definition) { + super(schemaVersion, runOnRequirements, entitiesArray, initialData, definition); + } + + @Parameterized.Parameters(name = "{0}: {1}") + public static Collection data() throws URISyntaxException, IOException { + return getTestData("unified-test-format/auth"); + } +} From 23db52bec538c37720f22e8c92e88aa706dcff9d Mon Sep 17 00:00:00 2001 From: Maxim Katcharov Date: Wed, 5 Apr 2023 14:59:53 -0600 Subject: [PATCH 02/19] Scripts --- .evergreen/prepare-oidc-get-tokens-docker.sh | 50 ++++++++++++++++++++ .evergreen/prepare-oidc-server-docker.sh | 50 ++++++++++++++++++++ 2 files changed, 100 insertions(+) create mode 100755 .evergreen/prepare-oidc-get-tokens-docker.sh create mode 100755 .evergreen/prepare-oidc-server-docker.sh diff --git a/.evergreen/prepare-oidc-get-tokens-docker.sh b/.evergreen/prepare-oidc-get-tokens-docker.sh new file mode 100755 index 00000000000..e904d5d2b89 --- /dev/null +++ b/.evergreen/prepare-oidc-get-tokens-docker.sh @@ -0,0 +1,50 @@ +#!/bin/bash + +set -o xtrace +set -o errexit # Exit the script with error if any of the commands fail + +############################################ +# Main Program # +############################################ + +# Supported/used environment variables: +# DRIVERS_TOOLS The path to evergreeen tools +# OIDC_AWS_* Required OIDC_AWS_* env variables must be configured +# +# Environment variables used as output: +# OIDC_TESTS_ENABLED Allows running OIDC tests +# OIDC_TOKEN_DIR The path to generated OIDC AWS tokens +# AWS_WEB_IDENTITY_TOKEN_FILE The path to AWS token for device workflow + +if [ -z ${DRIVERS_TOOLS+x} ]; then + echo "DRIVERS_TOOLS. is not set"; + exit 1 +fi + +if [ -z ${OIDC_AWS_ROLE_ARN+x} ]; then + echo "OIDC_AWS_ROLE_ARN. is not set"; + exit 1 +fi + +if [ -z ${OIDC_AWS_SECRET_ACCESS_KEY+x} ]; then + echo "OIDC_AWS_SECRET_ACCESS_KEY. is not set"; + exit 1 +fi + +if [ -z ${OIDC_AWS_ACCESS_KEY_ID+x} ]; then + echo "OIDC_AWS_ACCESS_KEY_ID. is not set"; + exit 1 +fi + +export AWS_ROLE_ARN=${OIDC_AWS_ROLE_ARN} +export AWS_SECRET_ACCESS_KEY=${OIDC_AWS_SECRET_ACCESS_KEY} +export AWS_ACCESS_KEY_ID=${OIDC_AWS_ACCESS_KEY_ID} +export OIDC_FOLDER=${DRIVERS_TOOLS}/.evergreen/auth_oidc +export OIDC_TOKEN_DIR=${OIDC_FOLDER}/test_tokens +export AWS_WEB_IDENTITY_TOKEN_FILE=${OIDC_TOKEN_DIR}/test1 +export OIDC_TESTS_ENABLED=true + +echo "Configuring OIDC server for local authentication tests" + +cd ${OIDC_FOLDER} +DRIVERS_TOOLS=${DRIVERS_TOOLS} ./oidc_get_tokens.sh \ No newline at end of file diff --git a/.evergreen/prepare-oidc-server-docker.sh b/.evergreen/prepare-oidc-server-docker.sh new file mode 100755 index 00000000000..0fcd1ed4194 --- /dev/null +++ b/.evergreen/prepare-oidc-server-docker.sh @@ -0,0 +1,50 @@ +#!/bin/bash + +set -o xtrace +set -o errexit # Exit the script with error if any of the commands fail + +############################################ +# Main Program # +############################################ + +# Supported/used environment variables: +# DRIVERS_TOOLS The path to evergreeen tools +# OIDC_AWS_* OIDC_AWS_* env variables must be configured +# +# Environment variables used as output: +# OIDC_TESTS_ENABLED Allows running OIDC tests +# OIDC_TOKEN_DIR The path to generated tokens +# AWS_WEB_IDENTITY_TOKEN_FILE The path to AWS token for device workflow + +if [ -z ${DRIVERS_TOOLS+x} ]; then + echo "DRIVERS_TOOLS. is not set"; + exit 1 +fi + +if [ -z ${OIDC_AWS_ROLE_ARN+x} ]; then + echo "OIDC_AWS_ROLE_ARN. is not set"; + exit 1 +fi + +if [ -z ${OIDC_AWS_SECRET_ACCESS_KEY+x} ]; then + echo "OIDC_AWS_SECRET_ACCESS_KEY. is not set"; + exit 1 +fi + +if [ -z ${OIDC_AWS_ACCESS_KEY_ID+x} ]; then + echo "OIDC_AWS_ACCESS_KEY_ID. is not set"; + exit 1 +fi + +export AWS_ROLE_ARN=${OIDC_AWS_ROLE_ARN} +export AWS_SECRET_ACCESS_KEY=${OIDC_AWS_SECRET_ACCESS_KEY} +export AWS_ACCESS_KEY_ID=${OIDC_AWS_ACCESS_KEY_ID} +export OIDC_FOLDER=${DRIVERS_TOOLS}/.evergreen/auth_oidc +export OIDC_TOKEN_DIR=${OIDC_FOLDER}/test_tokens +export AWS_WEB_IDENTITY_TOKEN_FILE=${OIDC_TOKEN_DIR}/test1 +export OIDC_TESTS_ENABLED=true + +echo "Configuring OIDC server for local authentication tests" + +cd ${OIDC_FOLDER} +DRIVERS_TOOLS=${DRIVERS_TOOLS} ./start_local_server.sh \ No newline at end of file From db2880d2e57398453c783b5babd9f75f29e7b996 Mon Sep 17 00:00:00 2001 From: Maxim Katcharov Date: Thu, 18 May 2023 16:11:52 -0600 Subject: [PATCH 03/19] Implement OIDC auth for sync --- .../src/test/unit/util/ThreadTestHelpers.java | 12 +- .../com/mongodb/AuthenticationMechanism.java | 7 + .../main/com/mongodb/ConnectionString.java | 5 + .../src/main/com/mongodb/MongoCredential.java | 217 +++- .../internal/connection/Authenticator.java | 12 + .../internal/connection/AwsAuthenticator.java | 41 +- .../connection/InternalStreamConnection.java | 70 +- .../InternalStreamConnectionFactory.java | 23 +- .../InternalStreamConnectionInitializer.java | 20 +- .../connection/MongoCredentialWithCache.java | 30 +- .../connection/OidcAuthenticator.java | 664 +++++++++++ .../connection/SaslAuthenticator.java | 61 +- .../connection/ScramShaAuthenticator.java | 48 +- .../com/mongodb/client/TestHelper.java | 47 + .../com/mongodb/client/TestListener.java | 40 + .../connection/TestCommandListener.java | 38 +- .../auth/legacy/connection-string.json | 14 + .../com/mongodb/AuthConnectionStringTest.java | 51 +- .../client/OidcAuthenticationProseTests.java | 1011 +++++++++++++++++ 19 files changed, 2288 insertions(+), 123 deletions(-) create mode 100644 driver-core/src/main/com/mongodb/internal/connection/OidcAuthenticator.java create mode 100644 driver-core/src/test/functional/com/mongodb/client/TestHelper.java create mode 100644 driver-core/src/test/functional/com/mongodb/client/TestListener.java create mode 100644 driver-sync/src/test/functional/com/mongodb/client/OidcAuthenticationProseTests.java diff --git a/bson/src/test/unit/util/ThreadTestHelpers.java b/bson/src/test/unit/util/ThreadTestHelpers.java index a4767c503f9..e2115da079f 100644 --- a/bson/src/test/unit/util/ThreadTestHelpers.java +++ b/bson/src/test/unit/util/ThreadTestHelpers.java @@ -31,15 +31,19 @@ private ThreadTestHelpers() { } public static void executeAll(final int nThreads, final Runnable c) { + executeAll(Collections.nCopies(nThreads, c).toArray(new Runnable[0])); + } + + public static void executeAll(final Runnable... runnables) { ExecutorService service = null; try { - service = Executors.newFixedThreadPool(nThreads); - CountDownLatch latch = new CountDownLatch(nThreads); + service = Executors.newFixedThreadPool(runnables.length); + CountDownLatch latch = new CountDownLatch(runnables.length); List failures = Collections.synchronizedList(new ArrayList<>()); - for (int i = 0; i < nThreads; i++) { + for (final Runnable runnable : runnables) { service.submit(() -> { try { - c.run(); + runnable.run(); } catch (Throwable e) { failures.add(e); } finally { diff --git a/driver-core/src/main/com/mongodb/AuthenticationMechanism.java b/driver-core/src/main/com/mongodb/AuthenticationMechanism.java index db8a909b79d..7a7b7415ef6 100644 --- a/driver-core/src/main/com/mongodb/AuthenticationMechanism.java +++ b/driver-core/src/main/com/mongodb/AuthenticationMechanism.java @@ -37,6 +37,13 @@ public enum AuthenticationMechanism { */ MONGODB_AWS("MONGODB-AWS"), + /** + * The MONGODB-OIDC mechanism. + * @since 4.10 + * @mongodb.server.release 7.0 + */ + MONGODB_OIDC("MONGODB-OIDC"), + /** * The MongoDB X.509 mechanism. This mechanism is available only with client certificates over SSL. */ diff --git a/driver-core/src/main/com/mongodb/ConnectionString.java b/driver-core/src/main/com/mongodb/ConnectionString.java index cd8f359f35c..902a57bc495 100644 --- a/driver-core/src/main/com/mongodb/ConnectionString.java +++ b/driver-core/src/main/com/mongodb/ConnectionString.java @@ -40,6 +40,7 @@ import java.util.Set; import java.util.concurrent.TimeUnit; +import static com.mongodb.internal.connection.OidcAuthenticator.OidcValidator.validateCreateOidcCredential; import static java.lang.String.format; import static java.util.Arrays.asList; import static java.util.Collections.singletonList; @@ -919,6 +920,10 @@ private MongoCredential createMongoCredentialWithMechanism(final AuthenticationM case MONGODB_AWS: credential = MongoCredential.createAwsCredential(userName, password); break; + case MONGODB_OIDC: + validateCreateOidcCredential(password); + credential = MongoCredential.createOidcCredential(userName); + break; default: throw new UnsupportedOperationException(format("The connection string contains an invalid authentication mechanism'. " + "'%s' is not a supported authentication mechanism", diff --git a/driver-core/src/main/com/mongodb/MongoCredential.java b/driver-core/src/main/com/mongodb/MongoCredential.java index ffa2a3c4e02..ff3d4cc168d 100644 --- a/driver-core/src/main/com/mongodb/MongoCredential.java +++ b/driver-core/src/main/com/mongodb/MongoCredential.java @@ -20,19 +20,23 @@ import com.mongodb.annotations.Immutable; import com.mongodb.lang.Nullable; +import java.time.Duration; import java.util.Arrays; import java.util.Collections; import java.util.HashMap; +import java.util.List; import java.util.Map; import java.util.Objects; import static com.mongodb.AuthenticationMechanism.GSSAPI; import static com.mongodb.AuthenticationMechanism.MONGODB_AWS; +import static com.mongodb.AuthenticationMechanism.MONGODB_OIDC; import static com.mongodb.AuthenticationMechanism.MONGODB_X509; import static com.mongodb.AuthenticationMechanism.PLAIN; import static com.mongodb.AuthenticationMechanism.SCRAM_SHA_1; import static com.mongodb.AuthenticationMechanism.SCRAM_SHA_256; import static com.mongodb.assertions.Assertions.notNull; +import static com.mongodb.internal.connection.OidcAuthenticator.OidcValidator.validateOidcCredentialConstruction; /** * Represents credentials to authenticate to a mongo server,as well as the source of the credentials and the authentication mechanism to @@ -179,6 +183,68 @@ public final class MongoCredential { @Beta(Beta.Reason.CLIENT) public static final String AWS_CREDENTIAL_PROVIDER_KEY = "AWS_CREDENTIAL_PROVIDER"; + /** + * The provider name. The value must be a string. + *

+ * If this is provided, neither + * {@link MongoCredential#REQUEST_TOKEN_CALLBACK_KEY} nor + * {@link MongoCredential#REFRESH_TOKEN_CALLBACK_KEY} + * must not be provided. + * + * @see #createOidcCredential(String) + * @since 4.10 + */ + public static final String PROVIDER_NAME_KEY = "PROVIDER_NAME"; + + /** + * This callback is invoked when the OIDC-based authenticator requests + * tokens from the identity provider. The type of the value must be + * {@link OidcRequestCallback}. + *

+ * If this is provided, {@link MongoCredential#PROVIDER_NAME_KEY} + * must not be provided. + * + * @see #createOidcCredential(String) + * @since 4.10 + */ + public static final String REQUEST_TOKEN_CALLBACK_KEY = "REQUEST_TOKEN_CALLBACK"; + + /** + * Mechanism key for invoked when the OIDC-based authenticator refreshes + * tokens from the identity provider. If this callback is not provided, + * then refresh operations will not be attempted.The type of the value + * must be {@link OidcRefreshCallback}. + *

+ * If this is provided, {@link MongoCredential#PROVIDER_NAME_KEY} + * must not be provided. + * + * @see #createOidcCredential(String) + * @since 4.10 + */ + public static final String REFRESH_TOKEN_CALLBACK_KEY = "REFRESH_TOKEN_CALLBACK"; + + /** + * Mechanism key for a list of allowed hostnames or ip-addresses (ignoring ports) for MongoDB connections. + * The hostnames may include a leading "*." wildcard, which allows for matching (potentially nested) subdomains. + * When MONGODB-OIDC authentication is attempted against a hostname that does not match any of list of allowed hosts + * the driver will raise an error. The type of the value must be {@code List}. + * + * @see MongoCredential#DEFAULT_ALLOWED_HOSTS + * @see #createOidcCredential(String) + * @since 4.10 + */ + public static final String ALLOWED_HOSTS_KEY = "ALLOWED_HOSTS"; + + /** + * The list of allowed hosts that will be used if no + * {@link MongoCredential#ALLOWED_HOSTS_KEY} value is supplied. + * + * @see #createOidcCredential(String) + * @since 4.10 + */ + public static final List DEFAULT_ALLOWED_HOSTS = Collections.unmodifiableList(Arrays.asList( + "*.mongodb.net", "*.mongodb-dev.net", "*.mongodbgov.net", "localhost", "127.0.0.1", "::1")); + /** * Creates a MongoCredential instance with an unspecified mechanism. The client will negotiate the best mechanism based on the * version of the server that the client is authenticating to. @@ -327,6 +393,23 @@ public static MongoCredential createAwsCredential(@Nullable final String userNam return new MongoCredential(MONGODB_AWS, userName, "$external", password); } + /** + * Creates a MongoCredential instance for the MONGODB-OIDC mechanism. + * + * @param userName the user name, which may be null. This is the OIDC principal name. + * @return the credential + * @since 4.10 + * @see #withMechanismProperty(String, Object) + * @see #PROVIDER_NAME_KEY + * @see #REQUEST_TOKEN_CALLBACK_KEY + * @see #REFRESH_TOKEN_CALLBACK_KEY + * @see #ALLOWED_HOSTS_KEY + * @mongodb.server.release 7.0 + */ + public static MongoCredential createOidcCredential(@Nullable final String userName) { + return new MongoCredential(MONGODB_OIDC, userName, "$external", null); + } + /** * Creates a new MongoCredential as a copy of this instance, with the specified mechanism property added. * @@ -370,7 +453,11 @@ public MongoCredential withMechanism(final AuthenticationMechanism mechanism) { MongoCredential(@Nullable final AuthenticationMechanism mechanism, @Nullable final String userName, final String source, @Nullable final char[] password, final Map mechanismProperties) { - if (userName == null && !Arrays.asList(MONGODB_X509, MONGODB_AWS).contains(mechanism)) { + if (mechanism == MONGODB_OIDC) { + validateOidcCredentialConstruction(source, mechanismProperties); + } + + if (userName == null && !Arrays.asList(MONGODB_X509, MONGODB_AWS, MONGODB_OIDC).contains(mechanism)) { throw new IllegalArgumentException("username can not be null"); } @@ -543,4 +630,132 @@ public String toString() { + ", mechanismProperties=" + '}'; } + + /** + * The context for the {@link OidcRequestCallback#onRequest(OidcRequestContext) OIDC request callback}. + */ + public interface OidcRequestContext { + /** + * @return The OIDC Identity Provider's configuration that can be used to acquire an Access Token. + */ + IdpInfo getIdpInfo(); + + /** + * @return The timeout that this callback must complete within. + */ + Duration getTimeout(); + } + + /** + * The context for the {@link OidcRefreshCallback#onRefresh(OidcRefreshContext) OIDC refresh callback}. + */ + public interface OidcRefreshContext extends OidcRequestContext { + /** + * @return The OIDC Refresh token supplied by a prior callback invocation. + */ + String getRefreshToken(); + } + + /** + * This callback is invoked when the OIDC-based authenticator requests + * tokens from the identity provider. + *

+ * It does not have to be thread-safe, unless it is provided to multiple + * MongoClients. + */ + public interface OidcRequestCallback { + /** + * @param context The context. + * @return The response produced by an OIDC Identity Provider + */ + IdpResponse onRequest(OidcRequestContext context); + } + + /** + * This callback is invoked when the OIDC-based authenticator refreshes + * tokens from the identity provider. If this callback is not provided, + * then refresh operations will not be attempted. + *

+ * It does not have to be thread-safe, unless it is provided to multiple + * MongoClients. + */ + public interface OidcRefreshCallback { + /** + * @param context The context. + * @return The response produced by an OIDC Identity Provider + */ + IdpResponse onRefresh(OidcRefreshContext context); + } + + /** + * The OIDC Identity Provider's configuration that can be used to acquire an Access Token. + */ + public interface IdpInfo { + /** + * @return URL which describes the Authorization Server. This identifier is the + * iss of provided access tokens, and is viable for RFC8414 metadata + * discovery and RFC9207 identification. + */ + String getIssuer(); + + /** + * @return Unique client ID for this OIDC client. + */ + String getClientId(); + + /** + * @return Additional scopes to request from Identity Provider. Immutable. + */ + List getRequestScopes(); + } + + /** + * The response produced by an OIDC Identity Provider. + */ + public static final class IdpResponse { + + private final String accessToken; + + @Nullable + private final Integer expiresInSeconds; + + @Nullable + private final String refreshToken; + + /** + * @param accessToken The OIDC access token + * @param expiresInSeconds The expiration in seconds. If null, the access token is single-use. + * @param refreshToken The refresh token. If null, refresh will not be attempted. + */ + public IdpResponse(final String accessToken, @Nullable final Integer expiresInSeconds, + @Nullable final String refreshToken) { + notNull("accessToken", accessToken); + this.accessToken = accessToken; + this.expiresInSeconds = expiresInSeconds; + this.refreshToken = refreshToken; + } + + /** + * @return The OIDC access token. + */ + public String getAccessToken() { + return accessToken; + } + + /** + * @return The expiration time in seconds. If null, the access token is single-use. + */ + @Nullable + public Integer getExpiresInSeconds() { + return expiresInSeconds; + } + + /** + * @return The OIDC refresh token. If null, refresh will not be attempted. + */ + @Nullable + public String getRefreshToken() { + return refreshToken; + } + } } diff --git a/driver-core/src/main/com/mongodb/internal/connection/Authenticator.java b/driver-core/src/main/com/mongodb/internal/connection/Authenticator.java index 9ec4780d958..96c66affb0c 100644 --- a/driver-core/src/main/com/mongodb/internal/connection/Authenticator.java +++ b/driver-core/src/main/com/mongodb/internal/connection/Authenticator.java @@ -21,6 +21,7 @@ import com.mongodb.ServerApi; import com.mongodb.connection.ClusterConnectionMode; import com.mongodb.connection.ConnectionDescription; +import com.mongodb.connection.ServerType; import com.mongodb.internal.async.SingleResultCallback; import com.mongodb.lang.NonNull; import com.mongodb.lang.Nullable; @@ -42,6 +43,11 @@ public abstract class Authenticator { this.serverApi = serverApi; } + public static boolean shouldAuthenticate(@Nullable final Authenticator authenticator, + final ConnectionDescription connectionDescription) { + return authenticator != null && connectionDescription.getServerType() != ServerType.REPLICA_SET_ARBITER; + } + @NonNull MongoCredentialWithCache getMongoCredentialWithCache() { return credential; @@ -93,4 +99,10 @@ T getNonNullMechanismProperty(final String key, @Nullable final T defaultVal abstract void authenticateAsync(InternalConnection connection, ConnectionDescription connectionDescription, SingleResultCallback callback); + + public void reauthenticate(final InternalConnection connection) { + throw new UnsupportedOperationException( + "Reauthentication requested by server but is not supported by specified mechanism."); + } + } diff --git a/driver-core/src/main/com/mongodb/internal/connection/AwsAuthenticator.java b/driver-core/src/main/com/mongodb/internal/connection/AwsAuthenticator.java index ec0fc3f9c8f..c2382bc9ba5 100644 --- a/driver-core/src/main/com/mongodb/internal/connection/AwsAuthenticator.java +++ b/driver-core/src/main/com/mongodb/internal/connection/AwsAuthenticator.java @@ -16,7 +16,6 @@ package com.mongodb.internal.connection; -import com.mongodb.AuthenticationMechanism; import com.mongodb.AwsCredential; import com.mongodb.MongoClientException; import com.mongodb.MongoCredential; @@ -77,27 +76,12 @@ protected SaslClient createSaslClient(final ServerAddress serverAddress) { return new AwsSaslClient(getMongoCredential()); } - private static class AwsSaslClient implements SaslClient { - private final MongoCredential credential; + private static class AwsSaslClient extends SaslClientImpl { private final byte[] clientNonce = new byte[RANDOM_LENGTH]; private int step = -1; AwsSaslClient(final MongoCredential credential) { - this.credential = credential; - } - - @Override - public String getMechanismName() { - AuthenticationMechanism authMechanism = credential.getAuthenticationMechanism(); - if (authMechanism == null) { - throw new IllegalArgumentException("Authentication mechanism cannot be null"); - } - return authMechanism.getMechanismName(); - } - - @Override - public boolean hasInitialResponse() { - return true; + super(credential); } @Override @@ -117,26 +101,6 @@ public boolean isComplete() { return step == 1; } - @Override - public byte[] unwrap(final byte[] bytes, final int i, final int i1) { - throw new UnsupportedOperationException("Not implemented yet!"); - } - - @Override - public byte[] wrap(final byte[] bytes, final int i, final int i1) { - throw new UnsupportedOperationException("Not implemented yet!"); - } - - @Override - public Object getNegotiatedProperty(final String s) { - throw new UnsupportedOperationException("Not implemented yet!"); - } - - @Override - public void dispose() { - // nothing to do - } - private byte[] computeClientFirstMessage() { new SecureRandom().nextBytes(this.clientNonce); @@ -184,6 +148,7 @@ private byte[] computeClientFinalMessage(final byte[] serverFirst) throws SaslEx private AwsCredential createAwsCredential() { AwsCredential awsCredential; + MongoCredential credential = getCredential(); if (credential.getUserName() != null) { if (credential.getPassword() == null) { throw new MongoClientException("secretAccessKey is required for AWS credential"); diff --git a/driver-core/src/main/com/mongodb/internal/connection/InternalStreamConnection.java b/driver-core/src/main/com/mongodb/internal/connection/InternalStreamConnection.java index e5526c175d6..894ac0a466e 100644 --- a/driver-core/src/main/com/mongodb/internal/connection/InternalStreamConnection.java +++ b/driver-core/src/main/com/mongodb/internal/connection/InternalStreamConnection.java @@ -18,6 +18,7 @@ import com.mongodb.LoggerSettings; import com.mongodb.MongoClientException; +import com.mongodb.MongoCommandException; import com.mongodb.MongoCompressor; import com.mongodb.MongoException; import com.mongodb.MongoInternalException; @@ -68,6 +69,7 @@ import java.util.Map; import java.util.Set; import java.util.concurrent.atomic.AtomicBoolean; +import java.util.function.Supplier; import static com.mongodb.assertions.Assertions.assertNotNull; import static com.mongodb.assertions.Assertions.isTrue; @@ -95,6 +97,19 @@ @NotThreadSafe public class InternalStreamConnection implements InternalConnection { + private static volatile boolean recordEverything = false; + + /** + * Will attempt to record events to the command listener that are usually + * suppressed. + * + * @param recordEverything whether to attempt to record everything + */ + @VisibleForTesting(otherwise = VisibleForTesting.AccessModifier.PRIVATE) + public static void setRecordEverything(final boolean recordEverything) { + InternalStreamConnection.recordEverything = recordEverything; + } + private static final Set SECURITY_SENSITIVE_COMMANDS = new HashSet<>(asList( "authenticate", "saslStart", @@ -114,6 +129,8 @@ public class InternalStreamConnection implements InternalConnection { private static final Logger LOGGER = Loggers.getLogger("connection"); private final ClusterConnectionMode clusterConnectionMode; + @Nullable + private final Authenticator authenticator; private final boolean isMonitoringConnection; private final ServerId serverId; private final ConnectionGenerationSupplier connectionGenerationSupplier; @@ -127,6 +144,7 @@ public class InternalStreamConnection implements InternalConnection { private final AtomicBoolean isClosed = new AtomicBoolean(); private final AtomicBoolean opened = new AtomicBoolean(); + private final AtomicBoolean authenticated = new AtomicBoolean(); private final List compressorList; private final LoggerSettings loggerSettings; @@ -153,11 +171,13 @@ public InternalStreamConnection(final ClusterConnectionMode clusterConnectionMod final StreamFactory streamFactory, final List compressorList, final CommandListener commandListener, final InternalConnectionInitializer connectionInitializer, @Nullable final InetAddressResolver inetAddressResolver) { - this(clusterConnectionMode, false, serverId, connectionGenerationSupplier, streamFactory, compressorList, + this(clusterConnectionMode, null, false, serverId, connectionGenerationSupplier, streamFactory, compressorList, LoggerSettings.builder().build(), commandListener, connectionInitializer, inetAddressResolver); } - public InternalStreamConnection(final ClusterConnectionMode clusterConnectionMode, final boolean isMonitoringConnection, + public InternalStreamConnection(final ClusterConnectionMode clusterConnectionMode, + @Nullable final Authenticator authenticator, + final boolean isMonitoringConnection, final ServerId serverId, final ConnectionGenerationSupplier connectionGenerationSupplier, final StreamFactory streamFactory, final List compressorList, @@ -165,6 +185,7 @@ public InternalStreamConnection(final ClusterConnectionMode clusterConnectionMod final CommandListener commandListener, final InternalConnectionInitializer connectionInitializer, @Nullable final InetAddressResolver inetAddressResolver) { this.clusterConnectionMode = clusterConnectionMode; + this.authenticator = authenticator; this.isMonitoringConnection = isMonitoringConnection; this.serverId = notNull("serverId", serverId); this.connectionGenerationSupplier = notNull("connectionGeneration", connectionGenerationSupplier); @@ -287,6 +308,7 @@ private void initAfterHandshakeFinish(final InternalConnectionInitializationDesc description = initializationDescription.getConnectionDescription(); initialServerDescription = initializationDescription.getServerDescription(); opened.set(true); + authenticated.set(true); sendCompressor = findSendCompressor(description); } @@ -352,8 +374,38 @@ public boolean isClosed() { @Override public T sendAndReceive(final CommandMessage message, final Decoder decoder, final SessionContext sessionContext, final RequestContext requestContext, final OperationContext operationContext) { - CommandEventSender commandEventSender; + if (!Authenticator.shouldAuthenticate(authenticator, this.description)) { + return sendAndReceiveInternal(message, decoder, sessionContext, requestContext, operationContext); + } + Supplier retryableOperation = () -> + sendAndReceiveInternal(message, decoder, sessionContext, requestContext, operationContext); + try { + return retryableOperation.get(); + } catch (MongoCommandException e) { + if (triggersReauthentication(e)) { + authenticated.set(false); + authenticator.reauthenticate(this); + authenticated.set(true); + return retryableOperation.get(); + } + throw e; + } + } + + public static boolean triggersReauthentication(@Nullable final Throwable t) { + if (t instanceof MongoCommandException) { + MongoCommandException e = (MongoCommandException) t; + return e.getErrorCode() == 391; + } + return false; + } + + @Nullable + private T sendAndReceiveInternal(final CommandMessage message, final Decoder decoder, + final SessionContext sessionContext, final RequestContext requestContext, + final OperationContext operationContext) { + CommandEventSender commandEventSender; try (ByteBufferBsonOutput bsonOutput = new ByteBufferBsonOutput(this)) { message.encode(bsonOutput, sessionContext); commandEventSender = createCommandEventSender(message, bsonOutput, requestContext, operationContext); @@ -466,7 +518,7 @@ private T receiveCommandMessageResponse(final Decoder decoder, commandEventSender.sendFailedEvent(e); } throw e; - } + } } @Override @@ -860,12 +912,14 @@ public void onResult(@Nullable final ByteBuf result, @Nullable final Throwable t private CommandEventSender createCommandEventSender(final CommandMessage message, final ByteBufferBsonOutput bsonOutput, final RequestContext requestContext, final OperationContext operationContext) { - if (!isMonitoringConnection && opened() && (commandListener != null || COMMAND_PROTOCOL_LOGGER.isRequired(DEBUG, getClusterId()))) { - return new LoggingCommandEventSender(SECURITY_SENSITIVE_COMMANDS, SECURITY_SENSITIVE_HELLO_COMMANDS, description, - commandListener, requestContext, operationContext, message, bsonOutput, COMMAND_PROTOCOL_LOGGER, loggerSettings); - } else { + boolean listensOrLogs = commandListener != null || COMMAND_PROTOCOL_LOGGER.isRequired(DEBUG, getClusterId()); + if (!recordEverything && (isMonitoringConnection || !opened() || !authenticated.get() || !listensOrLogs)) { return new NoOpCommandEventSender(); } + return new LoggingCommandEventSender( + SECURITY_SENSITIVE_COMMANDS, SECURITY_SENSITIVE_HELLO_COMMANDS, description, commandListener, + requestContext, operationContext, message, bsonOutput, + COMMAND_PROTOCOL_LOGGER, loggerSettings); } private ClusterId getClusterId() { diff --git a/driver-core/src/main/com/mongodb/internal/connection/InternalStreamConnectionFactory.java b/driver-core/src/main/com/mongodb/internal/connection/InternalStreamConnectionFactory.java index 2431a3b800a..9d4dc8aaaca 100644 --- a/driver-core/src/main/com/mongodb/internal/connection/InternalStreamConnectionFactory.java +++ b/driver-core/src/main/com/mongodb/internal/connection/InternalStreamConnectionFactory.java @@ -16,6 +16,7 @@ package com.mongodb.internal.connection; +import com.mongodb.AuthenticationMechanism; import com.mongodb.LoggerSettings; import com.mongodb.MongoCompressor; import com.mongodb.MongoDriverInformation; @@ -30,9 +31,9 @@ import java.util.List; -import static com.mongodb.assertions.Assertions.assertNotNull; import static com.mongodb.assertions.Assertions.notNull; import static com.mongodb.internal.connection.ClientMetadataHelper.createClientMetadataDocument; +import static com.mongodb.internal.connection.OidcAuthenticator.OidcValidator.validateBeforeUse; class InternalStreamConnectionFactory implements InternalConnectionFactory { private final ClusterConnectionMode clusterConnectionMode; @@ -80,18 +81,21 @@ class InternalStreamConnectionFactory implements InternalConnectionFactory { @Override public InternalConnection create(final ServerId serverId, final ConnectionGenerationSupplier connectionGenerationSupplier) { Authenticator authenticator = credential == null ? null : createAuthenticator(credential); - return new InternalStreamConnection(clusterConnectionMode, isMonitoringConnection, serverId, connectionGenerationSupplier, + InternalStreamConnectionInitializer connectionInitializer = new InternalStreamConnectionInitializer( + clusterConnectionMode, authenticator, clientMetadataDocument, compressorList, serverApi); + return new InternalStreamConnection( + clusterConnectionMode, authenticator, + isMonitoringConnection, serverId, connectionGenerationSupplier, streamFactory, compressorList, loggerSettings, commandListener, - new InternalStreamConnectionInitializer(clusterConnectionMode, authenticator, clientMetadataDocument, compressorList, - serverApi), inetAddressResolver); + connectionInitializer, inetAddressResolver); } private Authenticator createAuthenticator(final MongoCredentialWithCache credential) { - if (credential.getAuthenticationMechanism() == null) { + AuthenticationMechanism authenticationMechanism = credential.getAuthenticationMechanism(); + if (authenticationMechanism == null) { return new DefaultAuthenticator(credential, clusterConnectionMode, serverApi); } - - switch (assertNotNull(credential.getAuthenticationMechanism())) { + switch (authenticationMechanism) { case GSSAPI: return new GSSAPIAuthenticator(credential, clusterConnectionMode, serverApi); case PLAIN: @@ -103,8 +107,11 @@ private Authenticator createAuthenticator(final MongoCredentialWithCache credent return new ScramShaAuthenticator(credential, clusterConnectionMode, serverApi); case MONGODB_AWS: return new AwsAuthenticator(credential, clusterConnectionMode, serverApi); + case MONGODB_OIDC: + validateBeforeUse(credential.getCredential()); + return new OidcAuthenticator(credential, clusterConnectionMode, serverApi); default: - throw new IllegalArgumentException("Unsupported authentication mechanism " + credential.getAuthenticationMechanism()); + throw new IllegalArgumentException("Unsupported authentication mechanism " + authenticationMechanism); } } } diff --git a/driver-core/src/main/com/mongodb/internal/connection/InternalStreamConnectionInitializer.java b/driver-core/src/main/com/mongodb/internal/connection/InternalStreamConnectionInitializer.java index 8b102182c05..753ff235d79 100644 --- a/driver-core/src/main/com/mongodb/internal/connection/InternalStreamConnectionInitializer.java +++ b/driver-core/src/main/com/mongodb/internal/connection/InternalStreamConnectionInitializer.java @@ -25,7 +25,6 @@ import com.mongodb.connection.ConnectionDescription; import com.mongodb.connection.ConnectionId; import com.mongodb.connection.ServerDescription; -import com.mongodb.connection.ServerType; import com.mongodb.internal.async.SingleResultCallback; import com.mongodb.lang.Nullable; import org.bson.BsonArray; @@ -81,8 +80,10 @@ public InternalConnectionInitializationDescription finishHandshake(final Interna final InternalConnectionInitializationDescription description) { notNull("internalConnection", internalConnection); notNull("description", description); - - authenticate(internalConnection, description.getConnectionDescription()); + final ConnectionDescription connectionDescription = description.getConnectionDescription(); + if (Authenticator.shouldAuthenticate(authenticator, connectionDescription)) { + authenticator.authenticate(internalConnection, connectionDescription); + } return completeConnectionDescriptionInitialization(internalConnection, description); } @@ -105,11 +106,12 @@ public void startHandshakeAsync(final InternalConnection internalConnection, public void finishHandshakeAsync(final InternalConnection internalConnection, final InternalConnectionInitializationDescription description, final SingleResultCallback callback) { - if (authenticator == null || description.getConnectionDescription().getServerType() - == ServerType.REPLICA_SET_ARBITER) { + ConnectionDescription connectionDescription = description.getConnectionDescription(); + + if (!Authenticator.shouldAuthenticate(authenticator, connectionDescription)) { completeConnectionDescriptionInitializationAsync(internalConnection, description, callback); } else { - authenticator.authenticateAsync(internalConnection, description.getConnectionDescription(), + authenticator.authenticateAsync(internalConnection, connectionDescription, (result1, t1) -> { if (t1 != null) { callback.onResult(null, t1); @@ -200,12 +202,6 @@ private InternalConnectionInitializationDescription completeConnectionDescriptio description); } - private void authenticate(final InternalConnection internalConnection, final ConnectionDescription connectionDescription) { - if (authenticator != null && connectionDescription.getServerType() != ServerType.REPLICA_SET_ARBITER) { - authenticator.authenticate(internalConnection, connectionDescription); - } - } - private void setSpeculativeAuthenticateResponse(final BsonDocument helloResult) { if (authenticator instanceof SpeculativeAuthenticator) { ((SpeculativeAuthenticator) authenticator).setSpeculativeAuthenticateResponse( diff --git a/driver-core/src/main/com/mongodb/internal/connection/MongoCredentialWithCache.java b/driver-core/src/main/com/mongodb/internal/connection/MongoCredentialWithCache.java index 6b0f6e4ec3c..50579c2b28d 100644 --- a/driver-core/src/main/com/mongodb/internal/connection/MongoCredentialWithCache.java +++ b/driver-core/src/main/com/mongodb/internal/connection/MongoCredentialWithCache.java @@ -18,13 +18,17 @@ import com.mongodb.AuthenticationMechanism; import com.mongodb.MongoCredential; +import com.mongodb.internal.Locks; import com.mongodb.lang.Nullable; import java.util.concurrent.locks.Lock; import java.util.concurrent.locks.ReentrantLock; +import java.util.function.Supplier; import static com.mongodb.internal.Locks.withLock; +import static com.mongodb.internal.connection.OidcAuthenticator.OidcCacheEntry; + /** *

This class is not part of the public API and may be removed or changed at any time

*/ @@ -33,12 +37,13 @@ public class MongoCredentialWithCache { private final Cache cache; public MongoCredentialWithCache(final MongoCredential credential) { - this(credential, null); + this.credential = credential; + this.cache = new Cache(); } - private MongoCredentialWithCache(final MongoCredential credential, @Nullable final Cache cache) { + private MongoCredentialWithCache(final MongoCredential credential, final Cache cache) { this.credential = credential; - this.cache = cache != null ? cache : new Cache(); + this.cache = cache; } public MongoCredentialWithCache withMechanism(final AuthenticationMechanism mechanism) { @@ -63,15 +68,34 @@ public void putInCache(final Object key, final Object value) { cache.set(key, value); } + public OidcCacheEntry getOidcCacheEntry() { + return cache.oidcCacheEntry; + } + + public void setOidcCacheEntry(final OidcCacheEntry oidcCacheEntry) { + this.cache.oidcCacheEntry = oidcCacheEntry; + } + + public V withOidcLock(final Supplier k) { + return Locks.withLock(cache.oidcLock, k); + } + public Lock getLock() { return cache.lock; } + /** + * Stores any state associated with the credential. + */ static class Cache { private final ReentrantLock lock = new ReentrantLock(); private Object cacheKey; private Object cacheValue; + + private final ReentrantLock oidcLock = new ReentrantLock(); + private volatile OidcCacheEntry oidcCacheEntry = new OidcCacheEntry(); + Object get(final Object key) { return withLock(lock, () -> { if (cacheKey != null && cacheKey.equals(key)) { diff --git a/driver-core/src/main/com/mongodb/internal/connection/OidcAuthenticator.java b/driver-core/src/main/com/mongodb/internal/connection/OidcAuthenticator.java new file mode 100644 index 00000000000..02b78237075 --- /dev/null +++ b/driver-core/src/main/com/mongodb/internal/connection/OidcAuthenticator.java @@ -0,0 +1,664 @@ +/* + * Copyright 2008-present MongoDB, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.mongodb.internal.connection; + +import com.mongodb.AuthenticationMechanism; +import com.mongodb.MongoClientException; +import com.mongodb.MongoCommandException; +import com.mongodb.MongoConfigurationException; +import com.mongodb.MongoCredential; +import com.mongodb.MongoCredential.IdpInfo; +import com.mongodb.MongoCredential.IdpResponse; +import com.mongodb.MongoException; +import com.mongodb.MongoSecurityException; +import com.mongodb.ServerAddress; +import com.mongodb.ServerApi; +import com.mongodb.connection.ClusterConnectionMode; +import com.mongodb.connection.ConnectionDescription; +import com.mongodb.lang.Nullable; +import org.bson.BsonBinaryWriter; +import org.bson.BsonDocument; +import org.bson.BsonString; +import org.bson.RawBsonDocument; +import org.bson.codecs.BsonDocumentCodec; +import org.bson.codecs.EncoderContext; +import org.bson.io.BasicOutputBuffer; +import org.jetbrains.annotations.NotNull; + +import javax.security.sasl.SaslClient; +import java.io.IOException; +import java.nio.charset.StandardCharsets; +import java.nio.file.Files; +import java.nio.file.Paths; +import java.time.Duration; +import java.time.Instant; +import java.time.temporal.ChronoUnit; +import java.util.Arrays; +import java.util.Collections; +import java.util.List; +import java.util.Map; +import java.util.function.Function; +import java.util.stream.Collectors; + +import static com.mongodb.AuthenticationMechanism.MONGODB_OIDC; +import static com.mongodb.MongoCredential.ALLOWED_HOSTS_KEY; +import static com.mongodb.MongoCredential.DEFAULT_ALLOWED_HOSTS; +import static com.mongodb.MongoCredential.OidcRefreshContext; +import static com.mongodb.MongoCredential.OidcRequestContext; +import static com.mongodb.MongoCredential.PROVIDER_NAME_KEY; +import static com.mongodb.MongoCredential.REFRESH_TOKEN_CALLBACK_KEY; +import static com.mongodb.MongoCredential.REQUEST_TOKEN_CALLBACK_KEY; +import static com.mongodb.MongoCredential.OidcRefreshCallback; +import static com.mongodb.MongoCredential.OidcRequestCallback; +import static com.mongodb.assertions.Assertions.assertFalse; +import static com.mongodb.assertions.Assertions.assertNotNull; +import static com.mongodb.assertions.Assertions.notNull; +import static java.lang.String.format; + +/** + *

This class is not part of the public API and may be removed or changed at any time

+ */ +public class OidcAuthenticator extends SaslAuthenticator { + + private static final List SUPPORTED_PROVIDERS = Arrays.asList("aws"); + + private static final Duration CALLBACK_TIMEOUT = Duration.ofMinutes(5); + + private static final String AWS_WEB_IDENTITY_TOKEN_FILE = "AWS_WEB_IDENTITY_TOKEN_FILE"; + + private ServerAddress serverAddress; + + private String connectionLastAccessToken; + + private FallbackState fallbackState = FallbackState.INITIAL; + + private BsonDocument speculativeAuthenticateResponse; + + private Function evaluateChallengeFunction; + + public OidcAuthenticator(final MongoCredentialWithCache credential, + final ClusterConnectionMode clusterConnectionMode, @Nullable final ServerApi serverApi) { + super(credential, clusterConnectionMode, serverApi); + + if (getMongoCredential().getAuthenticationMechanism() != MONGODB_OIDC) { + throw new MongoException("Incorrect mechanism: " + getMongoCredential().getMechanism()); + } + } + + @Override + public String getMechanismName() { + return MONGODB_OIDC.getMechanismName(); + } + + @Override + protected SaslClient createSaslClient(final ServerAddress serverAddress) { + this.serverAddress = serverAddress; + MongoCredentialWithCache mongoCredentialWithCache = getMongoCredentialWithCache(); + return new OidcSaslClient(mongoCredentialWithCache); + } + + @Override + @Nullable + public BsonDocument createSpeculativeAuthenticateCommand(final InternalConnection connection) { + try { + OidcRequestCallback requestCallback = getRequestCallback(); + if (requestCallback == null) { + return wrapInSpeculative(prepareAwsTokenFromFile()); + } + String cachedAccessToken = getValidCachedAccessToken(); + MongoCredentialWithCache mongoCredentialWithCache = getMongoCredentialWithCache(); + if (cachedAccessToken != null) { + connectionLastAccessToken = cachedAccessToken; + fallbackState = FallbackState.PHASE_1_CACHED_TOKEN; + return wrapInSpeculative(prepareTokenAsJwt(cachedAccessToken)); + } else if (mongoCredentialWithCache.getOidcCacheEntry().idpInfo == null) { + String userName = mongoCredentialWithCache.getCredential().getUserName(); + return wrapInSpeculative(prepareUsername(userName)); + } else { + // otherwise, skip speculative auth + return null; + } + } catch (Exception e) { + throw wrapException(e); + } + } + + @NotNull + private BsonDocument wrapInSpeculative(final byte[] outToken) { + BsonDocument startDocument = createSaslStartCommandDocument(outToken) + .append("db", new BsonString(getMongoCredential().getSource())); + appendSaslStartOptions(startDocument); + return startDocument; + } + + @Override + @Nullable + public BsonDocument getSpeculativeAuthenticateResponse() { + BsonDocument response = speculativeAuthenticateResponse; + // response should only be read once + this.speculativeAuthenticateResponse = null; + if (response == null) { + this.connectionLastAccessToken = null; + this.fallbackState = FallbackState.INITIAL; + } + return response; + } + + @Override + public void setSpeculativeAuthenticateResponse(@Nullable final BsonDocument response) { + speculativeAuthenticateResponse = response; + } + + @Nullable + private OidcRefreshCallback getRefreshCallback() { + return getMongoCredentialWithCache() + .getCredential() + .getMechanismProperty(REFRESH_TOKEN_CALLBACK_KEY, null); + } + + @Nullable + private OidcRequestCallback getRequestCallback() { + return getMongoCredentialWithCache() + .getCredential() + .getMechanismProperty(REQUEST_TOKEN_CALLBACK_KEY, null); + } + + @Override + public void reauthenticate(final InternalConnection connection) { + fallbackState = FallbackState.INITIAL; + authLock(connection, connection.getDescription()); + } + + @Override + public void authenticate(final InternalConnection connection, final ConnectionDescription connectionDescription) { + // method must only be called during original handshake; fail otherwise + assertFalse(connection.opened()); + // this method "wraps" the default authentication method in custom OIDC retry logic + String accessToken = getValidCachedAccessToken(); + if (accessToken != null) { + try { + authenticateUsing(connection, connectionDescription, (bytes) -> prepareTokenAsJwt(accessToken)); + } catch (MongoSecurityException e) { + if (triggersRetry(e)) { // TODO-OIDC-x unclear how to provide test coverage for this + authLock(connection, connectionDescription); + } + } + } else { + authLock(connection, connectionDescription); + } + } + + private static boolean triggersRetry(@Nullable final Throwable t) { + if (t instanceof MongoSecurityException) { + MongoSecurityException e = (MongoSecurityException) t; + Throwable cause = e.getCause(); + if (cause instanceof MongoCommandException) { + MongoCommandException commandCause = (MongoCommandException) cause; + return commandCause.getErrorCode() == 18; + } + } + return false; + } + + + private void authenticateUsing( + final InternalConnection connection, + final ConnectionDescription connectionDescription, + final Function evaluateChallengeFunction) { + this.evaluateChallengeFunction = evaluateChallengeFunction; + super.authenticate(connection, connectionDescription); + } + + private void authLock(final InternalConnection connection, final ConnectionDescription connectionDescription) { + MongoCredentialWithCache mongoCredentialWithCache = getMongoCredentialWithCache(); + mongoCredentialWithCache.withOidcLock(() -> { + while (true) { + try { + authenticateUsing(connection, connectionDescription, (challenge) -> evaluate(challenge)); + break; + } catch (MongoSecurityException e) { + OidcCacheEntry cacheEntry = mongoCredentialWithCache.getOidcCacheEntry(); + if (triggersRetry(e)) { + prepareRetry(e, cacheEntry); + } else { + throw e; + } + } + } + return null; + }); + } + + private byte[] evaluate(final byte[] challenge) { + OidcRequestCallback requestCallback = getRequestCallback(); + if (requestCallback == null) { + return prepareAwsTokenFromFile(); + } + + MongoCredentialWithCache mongoCredentialWithCache = getMongoCredentialWithCache(); + OidcCacheEntry cacheEntry = mongoCredentialWithCache.getOidcCacheEntry(); + String cachedAccessToken = getValidCachedAccessToken(); + String invalidConnectionAccessToken = connectionLastAccessToken; + String cachedRefreshToken = cacheEntry.refreshToken; + IdpInfo cachedIdpInfo = cacheEntry.idpInfo; + + if (cachedAccessToken != null) { + boolean cachedTokenIsInvalid = cachedAccessToken.equals(invalidConnectionAccessToken); + if (cachedTokenIsInvalid) { + mongoCredentialWithCache.setOidcCacheEntry(cacheEntry.clearAccessToken()); + cachedAccessToken = null; + } + } + OidcRefreshCallback refreshCallback = getRefreshCallback(); + if (cachedAccessToken != null) { + fallbackState = FallbackState.PHASE_1_CACHED_TOKEN; + return prepareTokenAsJwt(cachedAccessToken); + } else if (refreshCallback != null && cachedRefreshToken != null && cachedIdpInfo != null) { + fallbackState = FallbackState.PHASE_2_REFRESH_CALLBACK_TOKEN; + // Invoke Refresh Callback using cached Refresh Token + validateAllowedHosts(getMongoCredential()); + IdpResponse result = refreshCallback.onRefresh(new OidcRefreshContextImpl( + cachedIdpInfo, cachedRefreshToken, CALLBACK_TIMEOUT)); + return handleCallbackResult(cachedIdpInfo, result); + } else { + // cache is empty + if (fallbackState != FallbackState.PHASE_3A_PRINCIPAL && challenge.length == 0) { + fallbackState = FallbackState.PHASE_3A_PRINCIPAL; + return prepareUsername(mongoCredentialWithCache.getCredential().getUserName()); + } else { + fallbackState = FallbackState.PHASE_3B_REQUEST_CALLBACK_TOKEN; + IdpInfo idpInfo = toIdpInfo(challenge); + IdpResponse result = invokeRequestCallback(requestCallback, idpInfo); + return handleCallbackResult(idpInfo, result); + } + } + } + + private boolean clientIsComplete() { + return fallbackState != FallbackState.PHASE_3A_PRINCIPAL; + } + + private void prepareRetry(final MongoException e, final OidcCacheEntry cacheEntry) { + MongoCredentialWithCache mongoCredentialWithCache = getMongoCredentialWithCache(); + if (fallbackState == FallbackState.PHASE_1_CACHED_TOKEN) { + // a cached access token failed + mongoCredentialWithCache.setOidcCacheEntry(cacheEntry + .clearAccessToken()); + } else if (fallbackState == FallbackState.PHASE_2_REFRESH_CALLBACK_TOKEN) { + // a refresh token failed + mongoCredentialWithCache.setOidcCacheEntry(cacheEntry + .clearAccessToken() + .clearRefreshToken()); + } else { + // a clean-restart failed + mongoCredentialWithCache.setOidcCacheEntry(cacheEntry + .clearAccessToken() + .clearRefreshToken()); + throw e; + } + } + + @Nullable + private String getValidCachedAccessToken() { + MongoCredentialWithCache mongoCredentialWithCache = getMongoCredentialWithCache(); + OidcCacheEntry cacheEntry = mongoCredentialWithCache.getOidcCacheEntry(); + String cachedAccessToken = cacheEntry.accessToken; + if (cachedAccessToken == null) { + return null; + } + if (cacheEntry.isExpired()) { + return mongoCredentialWithCache.withOidcLock(() -> { + OidcCacheEntry mostRecentCacheEntry = mongoCredentialWithCache.getOidcCacheEntry(); + if (mostRecentCacheEntry.isExpired()) { + mongoCredentialWithCache.setOidcCacheEntry(mostRecentCacheEntry.clearAccessToken()); + return null; + } else { + return mostRecentCacheEntry.accessToken; + } + }); + } + return cachedAccessToken; + } + + static final class OidcCacheEntry { + @Nullable + private final String accessToken; + @Nullable + private final Instant expiry; + @Nullable + private final String refreshToken; + @Nullable + private final IdpInfo idpInfo; + + @Override + public String toString() { + return "OidcCacheEntry{" + + "\n accessToken#hashCode='" + (accessToken == null ? null : accessToken.hashCode()) + '\'' + + ",\n expiry=" + expiry + + ",\n refreshToken='" + refreshToken + '\'' + + ",\n idpInfo=" + idpInfo + + '}'; + } + + OidcCacheEntry(@Nullable final IdpInfo idpInfo, final IdpResponse idpResponse) { + Integer expiresInSeconds = idpResponse.getExpiresInSeconds(); + if (expiresInSeconds != null) { + final Instant expiry = Instant.now().plusSeconds(expiresInSeconds) + .minus(5, ChronoUnit.MINUTES); + this.accessToken = idpResponse.getAccessToken(); + this.expiry = expiry; + } else { + this.accessToken = null; + this.expiry = null; + } + String refreshToken = idpResponse.getRefreshToken(); + if (refreshToken != null) { + this.refreshToken = refreshToken; + this.idpInfo = idpInfo; + } else { + this.refreshToken = null; + this.idpInfo = null; + } + } + + OidcCacheEntry() { + this(null, null, null, null); + } + + private OidcCacheEntry(@Nullable final String accessToken, @Nullable final Instant expiry, + @Nullable final String refreshToken, @Nullable final IdpInfo idpInfo) { + this.accessToken = accessToken; + this.expiry = expiry; + this.refreshToken = refreshToken; + this.idpInfo = idpInfo; + } + + public boolean isExpired() { + return expiry == null || Instant.now().isAfter(expiry); + } + + public OidcCacheEntry clearAccessToken() { + return new OidcCacheEntry( + null, + null, + this.refreshToken, + this.idpInfo); + } + + public OidcCacheEntry clearRefreshToken() { + return new OidcCacheEntry( + this.accessToken, + this.expiry, + null, + null); + } + } + + private final class OidcSaslClient extends SaslClientImpl { + + private OidcSaslClient(final MongoCredentialWithCache mongoCredentialWithCache) { + super(mongoCredentialWithCache.getCredential()); + } + + @Override + public byte[] evaluateChallenge(final byte[] challenge) { + return evaluateChallengeInternal(challenge); + } + + @Override + public boolean isComplete() { + return clientIsComplete(); + } + + public byte[] evaluateChallengeInternal(final byte[] challenge) { + return evaluateChallengeFunction.apply(challenge); + } + } + + private static byte[] prepareAwsTokenFromFile() { + return toBson(new BsonDocument() + .append("jwt", new BsonString(readAwsTokenFromFile()))); + } + + private static String readAwsTokenFromFile() { + String path = System.getenv(AWS_WEB_IDENTITY_TOKEN_FILE); + if (path == null) { + throw new MongoClientException( + format("Environment variable must be specified: %s", AWS_WEB_IDENTITY_TOKEN_FILE)); + } + try { + return new String(Files.readAllBytes(Paths.get(path)), StandardCharsets.UTF_8); + } catch (IOException e) { + throw new MongoClientException(format( + "Could not read file specified by environment variable: %s at path: %s", + AWS_WEB_IDENTITY_TOKEN_FILE, path), e); + } + } + + private byte[] prepareUsername(@Nullable final String username) { + BsonDocument document = new BsonDocument(); + if (username != null) { + document = document.append("n", new BsonString(username)); + } + return toBson(document); + } + + private byte[] handleCallbackResult( + final IdpInfo serverInfo, + @Nullable final IdpResponse tokens) { + if (tokens == null) { + throw new MongoConfigurationException("Result of callback must not be null"); + } + OidcCacheEntry newEntry = new OidcCacheEntry(serverInfo, tokens); + getMongoCredentialWithCache().setOidcCacheEntry(newEntry); + return prepareTokenAsJwt(tokens.getAccessToken()); + } + + private IdpInfo toIdpInfo(final byte[] challenge) { + BsonDocument c = new RawBsonDocument(challenge); + String issuer = c.getString("issuer").getValue(); + String clientId = c.getString("clientId").getValue(); + return new IdpInfoImpl( + issuer, + clientId, + getStringArray(c, "requestScopes")); + } + + private IdpResponse invokeRequestCallback(final OidcRequestCallback requestCallback, + final IdpInfo serverInfo) { + validateAllowedHosts(getMongoCredential()); + return requestCallback.onRequest(new OidcRequestContextImpl(serverInfo, CALLBACK_TIMEOUT)); + } + + private void validateAllowedHosts(final MongoCredential credential) { + List allowedHosts = assertNotNull(credential.getMechanismProperty(ALLOWED_HOSTS_KEY, DEFAULT_ALLOWED_HOSTS)); + String host = serverAddress.getHost(); + boolean permitted = allowedHosts.stream().anyMatch(allowedHost -> { + if (allowedHost.startsWith("*.")) { + String ending = allowedHost.substring(1); + return host.endsWith(ending); + } else if (allowedHost.contains("*")) { + throw new IllegalArgumentException( + "Allowed host " + allowedHost + " contains invalid wildcard"); + } else { + return host.equals(allowedHost); + } + }); + if (!permitted) { + throw new MongoSecurityException( + credential, "Host not permitted by " + ALLOWED_HOSTS_KEY + ": " + host); + } + } + + @Nullable + private List getStringArray(final BsonDocument document, final String key) { + if (!document.containsKey(key) || document.isArray(key)) { + return null; + } + List result = document.getArray(key).getValues().stream() + // ignore non-string values from server, rather than error + .filter(v -> v.isString()) + .map(v -> v.asString().getValue()) + .collect(Collectors.toList()); + return Collections.unmodifiableList(result); + } + + private byte[] prepareTokenAsJwt(final String accessToken) { + connectionLastAccessToken = accessToken; + return toBson(new BsonDocument().append("jwt", new BsonString(accessToken))); + } + + private static byte[] toBson(final BsonDocument document) { + BasicOutputBuffer buffer = new BasicOutputBuffer(); + new BsonDocumentCodec().encode(new BsonBinaryWriter(buffer), document, EncoderContext.builder().build()); + byte[] bytes = new byte[buffer.size()]; + System.arraycopy(buffer.getInternalBuffer(), 0, bytes, 0, buffer.getSize()); + return bytes; + } + + public static final class OidcValidator { + private OidcValidator() { + } + + public static void validateOidcCredentialConstruction( + final String source, + final Map mechanismProperties) { + + if (!"$external".equals(source)) { + throw new IllegalArgumentException("source must be '$external'"); + } + + Object providerName = mechanismProperties.get(PROVIDER_NAME_KEY.toLowerCase()); + if (providerName != null) { + if (!(providerName instanceof String) || !SUPPORTED_PROVIDERS.contains(providerName)) { + throw new IllegalArgumentException(PROVIDER_NAME_KEY + " must be one of: " + SUPPORTED_PROVIDERS); + } + } + } + + public static void validateCreateOidcCredential(@Nullable final char[] password) { + if (password != null) { + throw new IllegalArgumentException("password must not be specified for " + + AuthenticationMechanism.MONGODB_OIDC); + } + } + + public static void validateBeforeUse(final MongoCredential credential) { + AuthenticationMechanism mechanism = credential.getAuthenticationMechanism(); + String userName = credential.getUserName(); + + if (mechanism == AuthenticationMechanism.MONGODB_OIDC) { + Object providerName = credential.getMechanismProperty(PROVIDER_NAME_KEY, null); + Object requestCallback = credential.getMechanismProperty(REQUEST_TOKEN_CALLBACK_KEY, null); + Object refreshCallback = credential.getMechanismProperty(REFRESH_TOKEN_CALLBACK_KEY, null); + if (providerName == null) { + // callback + if (requestCallback == null) { + throw new IllegalArgumentException("Either " + PROVIDER_NAME_KEY + " or " + + REQUEST_TOKEN_CALLBACK_KEY + " must be specified"); + } + } else { + // automatic + if (userName != null) { + throw new IllegalArgumentException("user name must not be specified when " + PROVIDER_NAME_KEY + " is specified"); + } + if (requestCallback != null) { + throw new IllegalArgumentException(REQUEST_TOKEN_CALLBACK_KEY + " must not be specified when " + PROVIDER_NAME_KEY + " is specified"); + } + if (refreshCallback != null) { + throw new IllegalArgumentException(REFRESH_TOKEN_CALLBACK_KEY + " must not be specified when " + PROVIDER_NAME_KEY + " is specified"); + } + } + } + } + } + + + public static class OidcRequestContextImpl implements OidcRequestContext { + private final IdpInfo idpInfo; + private final Duration timeout; + + public OidcRequestContextImpl(final IdpInfo idpInfo, final Duration timeout) { + notNull("idpInfo", idpInfo); + notNull("timeout", timeout); + this.idpInfo = idpInfo; + this.timeout = timeout; + } + + public IdpInfo getIdpInfo() { + return idpInfo; + } + + public Duration getTimeout() { + return timeout; + } + } + + public static final class OidcRefreshContextImpl extends OidcRequestContextImpl + implements OidcRefreshContext { + private final String refreshToken; + + public OidcRefreshContextImpl(final IdpInfo idpInfo, final String refreshToken, + final Duration timeout) { + super(idpInfo, timeout); + notNull("refreshToken", refreshToken); + this.refreshToken = refreshToken; + } + + public String getRefreshToken() { + return refreshToken; + } + } + + public static final class IdpInfoImpl implements IdpInfo { + private final String issuer; + private final String clientId; + + private final List requestScopes; + + public IdpInfoImpl(final String issuer, final String clientId, @Nullable final List requestScopes) { + this.issuer = issuer; + this.clientId = clientId; + this.requestScopes = requestScopes == null + ? Collections.emptyList() + : Collections.unmodifiableList(requestScopes); + } + + public String getIssuer() { + return issuer; + } + + public String getClientId() { + return clientId; + } + + public List getRequestScopes() { + return requestScopes; + } + } + + /** + * Represents what was sent in the last request to the MongoDB server. + */ + enum FallbackState { + INITIAL, + PHASE_1_CACHED_TOKEN, + PHASE_2_REFRESH_CALLBACK_TOKEN, + PHASE_3A_PRINCIPAL, + PHASE_3B_REQUEST_CALLBACK_TOKEN + } +} diff --git a/driver-core/src/main/com/mongodb/internal/connection/SaslAuthenticator.java b/driver-core/src/main/com/mongodb/internal/connection/SaslAuthenticator.java index ebcf81c8532..d1298be571e 100644 --- a/driver-core/src/main/com/mongodb/internal/connection/SaslAuthenticator.java +++ b/driver-core/src/main/com/mongodb/internal/connection/SaslAuthenticator.java @@ -16,6 +16,8 @@ package com.mongodb.internal.connection; +import com.mongodb.AuthenticationMechanism; +import com.mongodb.MongoCredential; import com.mongodb.MongoException; import com.mongodb.MongoInterruptedException; import com.mongodb.MongoSecurityException; @@ -55,6 +57,7 @@ abstract class SaslAuthenticator extends Authenticator implements SpeculativeAut super(credential, clusterConnectionMode, serverApi); } + @Override public void authenticate(final InternalConnection connection, final ConnectionDescription connectionDescription) { doAsSubject(() -> { SaslClient saslClient = createSaslClient(connection.getDescription().getServerAddress()); @@ -121,9 +124,11 @@ private void throwIfSaslClientIsNull(@Nullable final SaslClient saslClient) { } private BsonDocument getNextSaslResponse(final SaslClient saslClient, final InternalConnection connection) { - BsonDocument response = getSpeculativeAuthenticateResponse(); - if (response != null) { - return response; + if (!connection.opened()) { + BsonDocument response = getSpeculativeAuthenticateResponse(); + if (response != null) { + return response; + } } try { @@ -136,9 +141,9 @@ private BsonDocument getNextSaslResponse(final SaslClient saslClient, final Inte private void getNextSaslResponseAsync(final SaslClient saslClient, final InternalConnection connection, final SingleResultCallback callback) { - BsonDocument response = getSpeculativeAuthenticateResponse(); - SingleResultCallback errHandlingCallback = errorHandlingCallback(callback, LOGGER); try { + BsonDocument response = getSpeculativeAuthenticateResponse(); + SingleResultCallback errHandlingCallback = errorHandlingCallback(callback, LOGGER); if (response == null) { byte[] serverResponse = (saslClient.hasInitialResponse() ? saslClient.evaluateChallenge(new byte[0]) : null); sendSaslStartAsync(serverResponse, connection, (result, t) -> { @@ -331,8 +336,52 @@ private void continueConversation(final BsonDocument result) { disposeOfSaslClient(saslClient); } } - } + protected abstract static class SaslClientImpl implements SaslClient { + private final MongoCredential credential; + + public SaslClientImpl(final MongoCredential credential) { + this.credential = credential; + } + + @Override + public boolean hasInitialResponse() { + return true; + } + + @Override + public byte[] unwrap(final byte[] bytes, final int i, final int i1) { + throw new UnsupportedOperationException("Not implemented."); + } + + @Override + public byte[] wrap(final byte[] bytes, final int i, final int i1) { + throw new UnsupportedOperationException("Not implemented."); + } + + @Override + public Object getNegotiatedProperty(final String s) { + throw new UnsupportedOperationException("Not implemented."); + } + + @Override + public void dispose() { + // nothing to do + } + + @Override + public final String getMechanismName() { + AuthenticationMechanism authMechanism = getCredential().getAuthenticationMechanism(); + if (authMechanism == null) { + throw new IllegalArgumentException("Authentication mechanism cannot be null"); + } + return authMechanism.getMechanismName(); + } + + protected MongoCredential getCredential() { + return credential; + } + } } diff --git a/driver-core/src/main/com/mongodb/internal/connection/ScramShaAuthenticator.java b/driver-core/src/main/com/mongodb/internal/connection/ScramShaAuthenticator.java index 5dec0d90c1e..02bc7912c93 100644 --- a/driver-core/src/main/com/mongodb/internal/connection/ScramShaAuthenticator.java +++ b/driver-core/src/main/com/mongodb/internal/connection/ScramShaAuthenticator.java @@ -92,7 +92,7 @@ protected SaslClient createSaslClient(final ServerAddress serverAddress) { if (speculativeSaslClient != null) { return speculativeSaslClient; } - return new ScramShaSaslClient(getMongoCredentialWithCache(), randomStringGenerator, authenticationHashGenerator); + return new ScramShaSaslClient(getMongoCredentialWithCache().getCredential(), randomStringGenerator, authenticationHashGenerator); } @Override @@ -122,9 +122,7 @@ public void setSpeculativeAuthenticateResponse(@Nullable final BsonDocument resp } } - class ScramShaSaslClient implements SaslClient { - - private final MongoCredentialWithCache credential; + class ScramShaSaslClient extends SaslClientImpl { private final RandomStringGenerator randomStringGenerator; private final AuthenticationHashGenerator authenticationHashGenerator; private final String hAlgorithm; @@ -136,9 +134,11 @@ class ScramShaSaslClient implements SaslClient { private byte[] serverSignature; private int step = -1; - ScramShaSaslClient(final MongoCredentialWithCache credential, final RandomStringGenerator randomStringGenerator, - final AuthenticationHashGenerator authenticationHashGenerator) { - this.credential = credential; + ScramShaSaslClient( + final MongoCredential credential, + final RandomStringGenerator randomStringGenerator, + final AuthenticationHashGenerator authenticationHashGenerator) { + super(credential); this.randomStringGenerator = randomStringGenerator; this.authenticationHashGenerator = authenticationHashGenerator; if (assertNotNull(credential.getAuthenticationMechanism()).equals(SCRAM_SHA_1)) { @@ -150,14 +150,6 @@ class ScramShaSaslClient implements SaslClient { } } - public String getMechanismName() { - return assertNotNull(credential.getAuthenticationMechanism()).getMechanismName(); - } - - public boolean hasInitialResponse() { - return true; - } - public byte[] evaluateChallenge(final byte[] challenge) throws SaslException { step++; if (step == 0) { @@ -167,7 +159,8 @@ public byte[] evaluateChallenge(final byte[] challenge) throws SaslException { } else if (step == 2) { return validateServerSignature(challenge); } else { - throw new SaslException(format("Too many steps involved in the %s negotiation.", getMechanismName())); + throw new SaslException(format("Too many steps involved in the %s negotiation.", + super.getMechanismName())); } } @@ -184,22 +177,6 @@ public boolean isComplete() { return step == 2; } - public byte[] unwrap(final byte[] incoming, final int offset, final int len) { - throw new UnsupportedOperationException("Not implemented yet!"); - } - - public byte[] wrap(final byte[] outgoing, final int offset, final int len) { - throw new UnsupportedOperationException("Not implemented yet!"); - } - - public Object getNegotiatedProperty(final String propName) { - throw new UnsupportedOperationException("Not implemented yet!"); - } - - public void dispose() { - // nothing to do - } - private byte[] computeClientFirstMessage() { clientNonce = randomStringGenerator.generate(RANDOM_LENGTH); String clientFirstMessage = "n=" + getUserName() + ",r=" + clientNonce; @@ -318,9 +295,8 @@ private HashMap parseServerResponse(final String response) { return map; } - private String getUserName() { - String userName = credential.getCredential().getUserName(); + String userName = getCredential().getUserName(); if (userName == null) { throw new IllegalArgumentException("Username can not be null"); } @@ -328,8 +304,8 @@ private String getUserName() { } private String getAuthenicationHash() { - String password = authenticationHashGenerator.generate(credential.getCredential()); - if (credential.getAuthenticationMechanism() == SCRAM_SHA_256) { + String password = authenticationHashGenerator.generate(getCredential()); + if (getCredential().getAuthenticationMechanism() == SCRAM_SHA_256) { password = SaslPrep.saslPrepStored(password); } return password; diff --git a/driver-core/src/test/functional/com/mongodb/client/TestHelper.java b/driver-core/src/test/functional/com/mongodb/client/TestHelper.java new file mode 100644 index 00000000000..237c03c7e19 --- /dev/null +++ b/driver-core/src/test/functional/com/mongodb/client/TestHelper.java @@ -0,0 +1,47 @@ +/* + * Copyright 2008-present MongoDB, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.mongodb.client; + +import com.mongodb.lang.Nullable; + +import java.lang.reflect.Field; +import java.util.Map; + +import static java.lang.System.getenv; + +public final class TestHelper { + + public static void setEnvironmentVariable(final String name, @Nullable final String value) { + try { + Map env = getenv(); + Field field = env.getClass().getDeclaredField("m"); + field.setAccessible(true); + @SuppressWarnings("unchecked") + Map result = (Map) field.get(env); + if (value == null) { + result.remove(name); + } else { + result.put(name, value); + } + } catch (IllegalAccessException | NoSuchFieldException e) { + throw new RuntimeException(e); + } + } + + private TestHelper() { + } +} diff --git a/driver-core/src/test/functional/com/mongodb/client/TestListener.java b/driver-core/src/test/functional/com/mongodb/client/TestListener.java new file mode 100644 index 00000000000..e014d98af19 --- /dev/null +++ b/driver-core/src/test/functional/com/mongodb/client/TestListener.java @@ -0,0 +1,40 @@ +/* + * Copyright 2008-present MongoDB, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.mongodb.client; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; + +/** + * A simple listener that consumes string events, which can be checked in tests. + */ +public class TestListener { + private final List events = Collections.synchronizedList(new ArrayList<>()); + + public void add(final String s) { + events.add(s); + } + + public List getEventStrings() { + return new ArrayList<>(events); + } + + public void clear() { + events.clear(); + } +} diff --git a/driver-core/src/test/functional/com/mongodb/internal/connection/TestCommandListener.java b/driver-core/src/test/functional/com/mongodb/internal/connection/TestCommandListener.java index fcdbeccc420..c520b619fa4 100644 --- a/driver-core/src/test/functional/com/mongodb/internal/connection/TestCommandListener.java +++ b/driver-core/src/test/functional/com/mongodb/internal/connection/TestCommandListener.java @@ -18,11 +18,13 @@ import com.mongodb.MongoInterruptedException; import com.mongodb.MongoTimeoutException; +import com.mongodb.client.TestListener; import com.mongodb.event.CommandEvent; import com.mongodb.event.CommandFailedEvent; import com.mongodb.event.CommandListener; import com.mongodb.event.CommandStartedEvent; import com.mongodb.event.CommandSucceededEvent; +import com.mongodb.lang.Nullable; import org.bson.BsonDocument; import org.bson.BsonDocumentWriter; import org.bson.BsonDouble; @@ -55,6 +57,8 @@ public class TestCommandListener implements CommandListener { private final List eventTypes; private final List ignoredCommandMonitoringEvents; private final List events = new ArrayList<>(); + @Nullable + private volatile TestListener listener = null; private final Lock lock = new ReentrantLock(); private final Condition commandCompletedCondition = lock.newCondition(); private final boolean observeSensitiveCommands; @@ -91,10 +95,26 @@ public TestCommandListener(final List eventTypes, final List ign this.observeSensitiveCommands = observeSensitiveCommands; } + /** + * When this is set, this command listener will send string events to the + * listener in the form {@code " "}, where the event + * type will be lowercase and will omit the terms "command" and "event". + * For example: {@code "saslContinue succeeded"}. + * + * @see InternalStreamConnection#setRecordEverything(boolean) + * @param eventStrings the test listener + */ + public void setEventStrings(final TestListener eventStrings) { + this.listener = eventStrings; + } + public void reset() { lock.lock(); try { events.clear(); + if (listener != null) { + listener.add("CommandListener reset"); + } } finally { lock.unlock(); } @@ -109,6 +129,18 @@ public List getEvents() { } } + private void addEvent(final CommandEvent c) { + events.add(c); + if (listener != null) { + String className = c.getClass().getSimpleName() + .replace("Command", "") + .replace("Event", "") + .toLowerCase(); + // example: "saslContinue succeeded" + listener.add(c.getCommandName() + " " + className); + } + } + public CommandStartedEvent getCommandStartedEvent(final String commandName) { for (CommandEvent event : getCommandStartedEvents()) { if (event instanceof CommandStartedEvent) { @@ -226,7 +258,7 @@ else if (!observeSensitiveCommands) { } lock.lock(); try { - events.add(new CommandStartedEvent(event.getRequestContext(), event.getOperationId(), event.getRequestId(), + addEvent(new CommandStartedEvent(event.getRequestContext(), event.getOperationId(), event.getRequestId(), event.getConnectionDescription(), event.getDatabaseName(), event.getCommandName(), event.getCommand() == null ? null : getWritableClone(event.getCommand()))); } finally { @@ -249,7 +281,7 @@ else if (!observeSensitiveCommands) { } lock.lock(); try { - events.add(new CommandSucceededEvent(event.getRequestContext(), event.getOperationId(), event.getRequestId(), + addEvent(new CommandSucceededEvent(event.getRequestContext(), event.getOperationId(), event.getRequestId(), event.getConnectionDescription(), event.getCommandName(), event.getResponse() == null ? null : event.getResponse().clone(), event.getElapsedTime(TimeUnit.NANOSECONDS))); @@ -274,7 +306,7 @@ else if (!observeSensitiveCommands) { } lock.lock(); try { - events.add(event); + addEvent(event); commandCompletedCondition.signal(); } finally { lock.unlock(); diff --git a/driver-core/src/test/resources/auth/legacy/connection-string.json b/driver-core/src/test/resources/auth/legacy/connection-string.json index 3aa0ae2dc4d..1d69685df10 100644 --- a/driver-core/src/test/resources/auth/legacy/connection-string.json +++ b/driver-core/src/test/resources/auth/legacy/connection-string.json @@ -566,6 +566,20 @@ "valid": false, "credential": null }, + { + "description": "should throw an exception if provider name and request callback are specified", + "uri": "mongodb://localhost/?authMechanism=MONGODB-OIDC&authMechanismProperties=PROVIDER_NAME:aws", + "callback": ["oidcRequest"], + "valid": false, + "credential": null + }, + { + "description": "should throw an exception if provider name and refresh callback are specified", + "uri": "mongodb://localhost/?authMechanism=MONGODB-OIDC&authMechanismProperties=PROVIDER_NAME:aws", + "callback": ["oidcRefresh"], + "valid": false, + "credential": null + }, { "description": "should throw an exception when unsupported auth property is specified (MONGODB-OIDC)", "uri": "mongodb://localhost/?authMechanism=MONGODB-OIDC&authMechanismProperties=UnsupportedProperty:unexisted", diff --git a/driver-core/src/test/unit/com/mongodb/AuthConnectionStringTest.java b/driver-core/src/test/unit/com/mongodb/AuthConnectionStringTest.java index dfb81ba8de4..019b8d15a4b 100644 --- a/driver-core/src/test/unit/com/mongodb/AuthConnectionStringTest.java +++ b/driver-core/src/test/unit/com/mongodb/AuthConnectionStringTest.java @@ -16,9 +16,13 @@ package com.mongodb; +import com.mongodb.internal.connection.OidcAuthenticator; +import com.mongodb.lang.Nullable; import junit.framework.TestCase; +import org.bson.BsonArray; import org.bson.BsonDocument; import org.bson.BsonNull; +import org.bson.BsonString; import org.bson.BsonValue; import org.junit.Test; import org.junit.runner.RunWith; @@ -32,7 +36,10 @@ import java.util.Collection; import java.util.List; -// See https://github.com/mongodb/specifications/tree/master/source/auth/tests +import static com.mongodb.MongoCredential.REFRESH_TOKEN_CALLBACK_KEY; +import static com.mongodb.MongoCredential.REQUEST_TOKEN_CALLBACK_KEY; + +// See https://github.com/mongodb/specifications/tree/master/source/auth/legacy/tests @RunWith(Parameterized.class) public class AuthConnectionStringTest extends TestCase { private final String input; @@ -56,7 +63,7 @@ public void shouldPassAllOutcomes() { @Parameterized.Parameters(name = "{1}") public static Collection data() throws URISyntaxException, IOException { List data = new ArrayList<>(); - for (File file : JsonPoweredTestHelper.getTestFiles("/auth")) { + for (File file : JsonPoweredTestHelper.getTestFiles("/auth/legacy")) { BsonDocument testDocument = JsonPoweredTestHelper.getTestDocument(file); for (BsonValue test : testDocument.getArray("tests")) { data.add(new Object[]{file.getName(), test.asDocument().getString("description").getValue(), @@ -69,7 +76,7 @@ public static Collection data() throws URISyntaxException, IOException private void testInvalidUris() { Throwable expectedError = null; try { - new ConnectionString(input).getCredential(); + getMongoCredential(); } catch (Throwable t) { expectedError = t; } @@ -78,7 +85,7 @@ private void testInvalidUris() { } private void testValidUris() { - MongoCredential credential = new ConnectionString(input).getCredential(); + MongoCredential credential = getMongoCredential(); if (credential != null) { assertString("credential.source", credential.getSource()); @@ -99,6 +106,34 @@ private void testValidUris() { } } + @Nullable + private MongoCredential getMongoCredential() { + ConnectionString connectionString; + connectionString = new ConnectionString(input); + MongoCredential credential = connectionString.getCredential(); + if (credential != null) { + BsonArray callbacks = (BsonArray) getExpectedValue("callback"); + if (callbacks != null) { + for (BsonValue v : callbacks) { + String string = ((BsonString) v).getValue(); + if ("oidcRequest".equals(string)) { + credential = credential.withMechanismProperty( + REQUEST_TOKEN_CALLBACK_KEY, + (MongoCredential.OidcRequestCallback) (context) -> null); + } else if ("oidcRefresh".equals(string)) { + credential = credential.withMechanismProperty( + REFRESH_TOKEN_CALLBACK_KEY, + (MongoCredential.OidcRefreshCallback) (context) -> null); + } else { + fail("Unsupported callback: " + string); + } + } + } + OidcAuthenticator.OidcValidator.validateBeforeUse(credential); + } + return credential; + } + private void assertString(final String key, final String actual) { BsonValue expected = getExpectedValue(key); @@ -142,6 +177,14 @@ private void assertMechanismProperties(final MongoCredential credential) { } } else if ((document.get(key).isBoolean())) { boolean expectedValue = document.getBoolean(key).getValue(); + if (REQUEST_TOKEN_CALLBACK_KEY.equals(key)) { + assertTrue(actualMechanismProperty instanceof MongoCredential.OidcRequestCallback); + return; + } + if (REFRESH_TOKEN_CALLBACK_KEY.equals(key)) { + assertTrue(actualMechanismProperty instanceof MongoCredential.OidcRefreshCallback); + return; + } assertNotNull(actualMechanismProperty); assertEquals(expectedValue, actualMechanismProperty); } else { diff --git a/driver-sync/src/test/functional/com/mongodb/client/OidcAuthenticationProseTests.java b/driver-sync/src/test/functional/com/mongodb/client/OidcAuthenticationProseTests.java new file mode 100644 index 00000000000..2bb6fc794f8 --- /dev/null +++ b/driver-sync/src/test/functional/com/mongodb/client/OidcAuthenticationProseTests.java @@ -0,0 +1,1011 @@ +/* + * Copyright 2008-present MongoDB, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.mongodb.client; + +import com.mongodb.ConnectionString; +import com.mongodb.MongoClientSettings; +import com.mongodb.MongoCommandException; +import com.mongodb.MongoConfigurationException; +import com.mongodb.MongoCredential; +import com.mongodb.MongoCredential.IdpResponse; +import com.mongodb.MongoCredential.OidcRefreshCallback; +import com.mongodb.MongoSecurityException; +import com.mongodb.event.CommandListener; +import com.mongodb.internal.connection.InternalStreamConnection; +import com.mongodb.internal.connection.OidcAuthenticator; +import com.mongodb.internal.connection.TestCommandListener; +import com.mongodb.lang.Nullable; +import org.bson.BsonArray; +import org.bson.BsonBoolean; +import org.bson.BsonDocument; +import org.bson.BsonInt32; +import org.bson.BsonString; +import org.jetbrains.annotations.NotNull; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.function.Executable; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.CsvSource; +import org.opentest4j.AssertionFailedError; +import org.opentest4j.MultipleFailuresError; + +import java.io.IOException; +import java.nio.charset.StandardCharsets; +import java.nio.file.Files; +import java.nio.file.NoSuchFileException; +import java.nio.file.Path; +import java.nio.file.Paths; +import java.time.Duration; +import java.util.Arrays; +import java.util.Date; +import java.util.List; +import java.util.concurrent.ConcurrentLinkedQueue; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.function.Supplier; +import java.util.stream.Collectors; +import java.util.stream.Stream; + +import static com.mongodb.MongoCredential.ALLOWED_HOSTS_KEY; +import static com.mongodb.MongoCredential.IdpInfo; +import static com.mongodb.MongoCredential.OidcRefreshContext; +import static com.mongodb.MongoCredential.OidcRequestContext; +import static com.mongodb.MongoCredential.PROVIDER_NAME_KEY; +import static com.mongodb.MongoCredential.REFRESH_TOKEN_CALLBACK_KEY; +import static com.mongodb.MongoCredential.REQUEST_TOKEN_CALLBACK_KEY; +import static com.mongodb.MongoCredential.OidcRequestCallback; +import static com.mongodb.MongoCredential.createOidcCredential; +import static com.mongodb.client.TestHelper.setEnvironmentVariable; +import static java.lang.System.getenv; +import static java.util.Arrays.asList; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.junit.jupiter.api.Assertions.fail; +import static util.ThreadTestHelpers.executeAll; + +public class OidcAuthenticationProseTests { + + private static final String AWS_WEB_IDENTITY_TOKEN_FILE = "AWS_WEB_IDENTITY_TOKEN_FILE"; + + public static final String TOKEN_DIRECTORY = "/tmp/tokens/"; // TODO-OIDC + + protected static final String OIDC_URL = "mongodb://localhost/?authMechanism=MONGODB-OIDC"; + private static final String AWS_OIDC_URL = + "mongodb://localhost/?authMechanism=MONGODB-OIDC&authMechanismProperties=PROVIDER_NAME:aws"; + + protected MongoClient createMongoClient(final MongoClientSettings settings) { + return MongoClients.create(settings); + } + + protected void setOidcFile(final String file) { + setEnvironmentVariable(AWS_WEB_IDENTITY_TOKEN_FILE, TOKEN_DIRECTORY + file); + } + + @BeforeEach + public void beforeEach() { + setOidcFile("test_user1"); + InternalStreamConnection.setRecordEverything(true); + // In each test, clearing the cache is not required, since there is no global cache + } + + @AfterEach + public void afterEach() { + InternalStreamConnection.setRecordEverything(false); + } + + @ParameterizedTest + @CsvSource(delimiter = '#', value = { + // 1.1 to 1.5: + "test1p1 # test_user1 # " + OIDC_URL, + "test1p2 # test_user1 # mongodb://test_user1@localhost/?authMechanism=MONGODB-OIDC", + "test1p3 # test_user1 # mongodb://test_user1@localhost:27018/?authMechanism=MONGODB-OIDC&directConnection=true&readPreference=secondaryPreferred", + "test1p4 # test_user2 # mongodb://test_user2@localhost:27018/?authMechanism=MONGODB-OIDC&directConnection=true&readPreference=secondaryPreferred", + "test1p5 # invalid # mongodb://localhost:27018/?authMechanism=MONGODB-OIDC&directConnection=true&readPreference=secondaryPreferred", + }) + public void test1CallbackDrivenAuth(final String name, final String file, final String url) { + boolean shouldPass = !file.equals("invalid"); + setOidcFile(file); + // #. Create a request callback that returns a valid token. + OidcRequestCallback onRequest = createCallback(); + // #. Create a client with a URL of the form ... and the OIDC request callback. + MongoClientSettings clientSettings = createSettings(url, onRequest, null); + // #. Perform a find operation that succeeds / fails + if (shouldPass) { + performFind(clientSettings); + } else { + performFind( + clientSettings, + MongoCommandException.class, + "Command failed with error 18 (AuthenticationFailed)"); + } + } + + // TODO-OIDC-x additional tests for token with null expiry? + + @ParameterizedTest + @CsvSource(delimiter = '#', value = { + // 1.6, both variants: + "'' # " + OIDC_URL, + "example.com # mongodb://localhost/?authMechanism=MONGODB-OIDC&ignored=example.com", + }) + public void test1p6CallbackDrivenAuthAllowedHostsBlocked(final String allowedHosts, final String url) { + // Create a client that uses the OIDC url and a request callback, and an ALLOWED_HOSTS that contains... + List allowedHostsList = asList(allowedHosts.split(",")); + MongoClientSettings settings = createSettings(url, createCallback(), null, allowedHostsList, null); + // #. Assert that a find operation fails with a client-side error. + performFind(settings, MongoSecurityException.class, ""); + } + + @Test + public void test1p7LockAvoidsExtraCallbackCalls() { + proveThatConcurrentCallbacksThrow(); + // The test requires that two operations are attempted concurrently. + // The delay on the next find should cause the initial request to delay + // and the ensuing refresh to block, rather than entering onRefresh. + // After blocking, this ensuing refresh thread will enter onRefresh. + AtomicInteger concurrent = new AtomicInteger(); + TestCallback onRequest = createCallback().setExpired().setConcurrentTracker(concurrent); + TestCallback onRefresh = createCallback().setConcurrentTracker(concurrent); + MongoClientSettings clientSettings = createSettings(OIDC_URL, onRequest, onRefresh); + try (MongoClient mongoClient = createMongoClient(clientSettings)) { + delayNextFind(); // cause both callbacks to be called + executeAll(2, () -> performFind(mongoClient)); + assertEquals(1, onRequest.getInvocations()); + assertEquals(1, onRefresh.getInvocations()); + } + } + + public void proveThatConcurrentCallbacksThrow() { + // ensure that, via delay, test callbacks throw when invoked concurrently + AtomicInteger c = new AtomicInteger(); + TestCallback request = createCallback().setConcurrentTracker(c).setDelayMs(5); + TestCallback refresh = createCallback().setConcurrentTracker(c); + IdpInfo serverInfo = new OidcAuthenticator.IdpInfoImpl("issuer", "clientId", asList()); + executeAll(() -> { + sleep(2); + assertThrows(RuntimeException.class, () -> { + refresh.onRefresh(new OidcAuthenticator.OidcRefreshContextImpl(serverInfo, "refToken", Duration.ofSeconds(1234))); + }); + }, () -> { + request.onRequest(new OidcAuthenticator.OidcRequestContextImpl(serverInfo, Duration.ofSeconds(1234))); + }); + } + + private void sleep(final long ms) { + try { + Thread.sleep(ms); + } catch (InterruptedException e) { + throw new RuntimeException(e); + } + } + + @ParameterizedTest + @CsvSource(delimiter = '#', value = { + // 2.1 to 2.3: + "test2p1 # test_user1 # " + AWS_OIDC_URL, + "test2p2 # test_user1 # mongodb://localhost:27018/?authMechanism=MONGODB-OIDC&authMechanismProperties=PROVIDER_NAME:aws&directConnection=true&readPreference=secondaryPreferred", + "test2p3 # test_user2 # mongodb://localhost:27018/?authMechanism=MONGODB-OIDC&authMechanismProperties=PROVIDER_NAME:aws&directConnection=true&readPreference=secondaryPreferred", + }) + public void test2AwsAutomaticAuth(final String name, final String file, final String url) { + setOidcFile(file); + // #. Create a client with a url of the form ... + MongoCredential credential = createOidcCredential(null) + .withMechanismProperty(PROVIDER_NAME_KEY, "aws"); + MongoClientSettings clientSettings = MongoClientSettings.builder() + .credential(credential) + .applyConnectionString(new ConnectionString(url)) + .build(); + // #. Perform a find operation that succeeds. + performFind(clientSettings); + } + + @Test + public void test2p4AllowedHostsIgnored() { + MongoClientSettings settings = createSettings( + AWS_OIDC_URL, null, null, Arrays.asList(), null); + performFind(settings); + } + + @Test + public void test3p1ValidCallbacks() { + String connectionString = "mongodb://test_user1@localhost/?authMechanism=MONGODB-OIDC"; + String expectedClientId = "0oadp0hpl7q3UIehP297"; + String expectedIssuer = "https://ebgxby0dw8.execute-api.us-west-1.amazonaws.com/default/mock-identity-config-oidc"; + Duration expectedSeconds = Duration.ofMinutes(5); + + TestCallback onRequest = createCallback().setExpired(); + TestCallback onRefresh = createCallback(); + // #. Verify that the request callback was called with the appropriate + // inputs, including the timeout parameter if possible. + // #. Verify that the refresh callback was called with the appropriate + // inputs, including the timeout parameter if possible. + OidcRequestCallback onRequest2 = (context) -> { + assertEquals(expectedClientId, context.getIdpInfo().getClientId()); + assertEquals(expectedIssuer, context.getIdpInfo().getIssuer()); + //assertEquals(Arrays.asList(""), serverInfo.getRequestScopes()); // TODO-OIDC-x fix when docker updated + assertEquals(expectedSeconds, context.getTimeout()); + return onRequest.onRequest(context); + }; + OidcRefreshCallback onRefresh2 = (context) -> { + assertEquals(expectedClientId, context.getIdpInfo().getClientId()); + assertEquals(expectedIssuer, context.getIdpInfo().getIssuer()); + //assertEquals(Arrays.asList(""), serverInfo.getRequestScopes()); // TODO-OIDC-x fix when docker updated + assertEquals(expectedSeconds, context.getTimeout()); + assertEquals("refreshToken", context.getRefreshToken()); + return onRefresh.onRefresh(context); + }; + MongoClientSettings clientSettings = createSettings(connectionString, onRequest2, onRefresh2); + try (MongoClient mongoClient = createMongoClient(clientSettings)) { + delayNextFind(); // cause both callbacks to be called + executeAll(2, () -> performFind(mongoClient)); + // Ensure that both callbacks were called + assertEquals(1, onRequest.getInvocations()); + assertEquals(1, onRefresh.getInvocations()); + } + } + + @Test + public void test3p2RequestCallbackReturnsNull() { + //noinspection ConstantConditions + OidcRequestCallback onRequest = (context) -> null; + MongoClientSettings settings = this.createSettings(OIDC_URL, onRequest, null); + performFind(settings, MongoConfigurationException.class, "Result of callback must not be null"); + } + + @Test + public void test3p3RefreshCallbackReturnsNull() { + TestCallback onRequest = createCallback().setExpired().setDelayMs(100); + //noinspection ConstantConditions + OidcRefreshCallback onRefresh = (context) -> null; + MongoClientSettings clientSettings = createSettings(OIDC_URL, onRequest, onRefresh); + try (MongoClient mongoClient = createMongoClient(clientSettings)) { + delayNextFind(); // cause both callbacks to be called + try { + executeAll(2, () -> performFind(mongoClient)); + } catch (MultipleFailuresError actual) { + assertEquals(1, actual.getFailures().size()); + assertCause( + MongoConfigurationException.class, + "Result of callback must not be null", + actual.getFailures().get(0)); + } + assertEquals(1, onRequest.getInvocations()); + } + } + + @Test + public void test3p4RequestCallbackReturnsInvalidData() { + // #. Create a client with a request callback that returns data not + // conforming to the OIDCRequestTokenResult with missing field(s). + // #. ... with extra field(s). - not possible + OidcRequestCallback onRequest = (context) -> { + //noinspection ConstantConditions + return new IdpResponse(null, null, null); + }; + // we ensure that the error is propagated + MongoClientSettings clientSettings = createSettings(OIDC_URL, onRequest, null); + try (MongoClient mongoClient = createMongoClient(clientSettings)) { + try { + performFind(mongoClient); + fail(); + } catch (Exception e) { + assertCause(IllegalArgumentException.class, "accessToken can not be null", e); + } + } + } + + @Test + public void test3p5RefreshCallbackReturnsInvalidData() { + TestCallback onRequest = createCallback().setExpired(); + OidcRefreshCallback onRefresh = (context) -> { + //noinspection ConstantConditions + return new IdpResponse(null, null, null); + }; + MongoClientSettings clientSettings = createSettings(OIDC_URL, onRequest, onRefresh); + try (MongoClient mongoClient = createMongoClient(clientSettings)) { + try { + executeAll(2, () -> performFind(mongoClient)); + } catch (MultipleFailuresError actual) { + assertEquals(1, actual.getFailures().size()); + assertCause( + IllegalArgumentException.class, + "accessToken can not be null", + actual.getFailures().get(0)); + } + assertEquals(1, onRequest.getInvocations()); + } + } + + // 3.6 Refresh Callback Returns Extra Data - not possible due to use of class + + @Test + public void test4p1CachedCredentialsCacheWithRefresh() { + // #. Create a new client with a request callback that gives credentials that expire in one minute. + TestCallback onRequest = createCallback().setExpired(); + TestCallback onRefresh = createCallback(); + MongoClientSettings clientSettings = createSettings(OIDC_URL, onRequest, onRefresh); + try (MongoClient mongoClient = createMongoClient(clientSettings)) { + // #. Create a new client with the same request callback and a refresh callback. + // Instead: + // 1. Delay the first find, causing the second find to authenticate a second connection + delayNextFind(); // cause both callbacks to be called + executeAll(2, () -> performFind(mongoClient)); + // #. Ensure that a find operation adds credentials to the cache. + // #. Ensure that a find operation results in a call to the refresh callback. + assertEquals(1, onRequest.getInvocations()); + assertEquals(1, onRefresh.getInvocations()); + // the refresh invocation will fail if the cached tokens are null + // so a success implies that credentials were present in the cache + } + } + + @Test + public void test4p2CachedCredentialsCacheWithNoRefresh() { + // #. Create a new client with a request callback that gives credentials that expire in one minute. + // #. Ensure that a find operation adds credentials to the cache. + // #. Create a new client with a request callback but no refresh callback. + // #. Ensure that a find operation results in a call to the request callback. + TestCallback onRequest = createCallback().setExpired(); + MongoClientSettings clientSettings = createSettings(OIDC_URL, onRequest, null); + try (MongoClient mongoClient = createMongoClient(clientSettings)) { + delayNextFind(); // cause both callbacks to be called + executeAll(2, () -> performFind(mongoClient)); + // test is the same as 4.1, but no onRefresh, and assert that the onRequest is called twice + assertEquals(2, onRequest.getInvocations()); + } + } + + // 4.3 Cache key includes callback - skipped: + // If the driver does not support using callback references or hashes as part of the cache key, skip this test. + + @Test + public void test4p4ErrorClearsCache() { + // #. Create a new client with a valid request callback that + // gives credentials that expire within 5 minutes and + // a refresh callback that gives invalid credentials. + + TestListener listener = new TestListener(); + ConcurrentLinkedQueue tokens = tokenQueue( + "test_user1", + "test_user1_expires", + "test_user1_expires", + "test_user1_1"); + TestCallback onRequest = createCallback() + .setExpired() + .setPathSupplier(() -> tokens.remove()) + .setEventListener(listener); + TestCallback onRefresh = createCallback() + .setPathSupplier(() -> tokens.remove()) + .setEventListener(listener); + + TestCommandListener commandListener = new TestCommandListener(); + commandListener.setEventStrings(listener); + + MongoClientSettings clientSettings = createSettings(OIDC_URL, onRequest, onRefresh, null, commandListener); + try (MongoClient mongoClient = createMongoClient(clientSettings)) { + // #. Ensure that a find operation adds a new entry to the cache. + performFind(mongoClient); + assertEquals(Arrays.asList( + "isMaster started", + "isMaster succeeded", + "onRequest invoked", + "read access token: test_user1", + "saslContinue started", + "saslContinue succeeded", + "find started", + "find succeeded" + ), listener.getEventStrings()); + listener.clear(); + + // #. Ensure that a subsequent find operation results in a 391 error. + failCommand(391, 1, "find"); + // ensure that the operation entirely fails, after attempting both potential fallback callbacks + assertThrows(MongoSecurityException.class, () -> performFind(mongoClient)); + assertEquals(Arrays.asList( + "find started", + "find failed", + "onRefresh invoked", + "read access token: test_user1_expires", + "saslStart started", + "saslStart failed", + // falling back to principal request, onRequest callback. + "saslStart started", + "saslStart succeeded", + "onRequest invoked", + "read access token: test_user1_expires", + "saslContinue started", + "saslContinue failed" + ), listener.getEventStrings()); + listener.clear(); + + // #. Ensure that the cache value cleared. + failCommand(391, 1, "find"); + performFind(mongoClient); + assertEquals(Arrays.asList( + "find started", + "find failed", + // falling back to principal request, onRequest callback. + // this implies that the cache has been cleared during the + // preceding find operation. + "saslStart started", + "saslStart succeeded", + "onRequest invoked", + "read access token: test_user1_1", + "saslContinue started", + "saslContinue succeeded", + // auth has finished + "find started", + "find succeeded" + ), listener.getEventStrings()); + listener.clear(); + } + } + + // not a prose test. + @Test + public void testEventListenerMustNotLogReauthentication() { + InternalStreamConnection.setRecordEverything(false); + + TestListener listener = new TestListener(); + ConcurrentLinkedQueue tokens = tokenQueue( + "test_user1", + "test_user1_expires", + "test_user1_expires", + "test_user1_1"); + TestCallback onRequest = createCallback() + .setExpired() + .setPathSupplier(() -> tokens.remove()) + .setEventListener(listener); + TestCallback onRefresh = createCallback() + .setPathSupplier(() -> tokens.remove()) + .setEventListener(listener); + + TestCommandListener commandListener = new TestCommandListener(); + commandListener.setEventStrings(listener); + + MongoClientSettings clientSettings = createSettings(OIDC_URL, onRequest, onRefresh, null, commandListener); + try (MongoClient mongoClient = createMongoClient(clientSettings)) { + performFind(mongoClient); + assertEquals(Arrays.asList( + "onRequest invoked", + "read access token: test_user1", + "find started", + "find succeeded" + ), listener.getEventStrings()); + listener.clear(); + + failCommand(391, 1, "find"); + assertThrows(MongoSecurityException.class, () -> performFind(mongoClient)); + assertEquals(Arrays.asList( + "find started", + "find failed", + "onRefresh invoked", + "read access token: test_user1_expires", + // falling back to principal request, onRequest callback + "onRequest invoked", + "read access token: test_user1_expires" + ), listener.getEventStrings()); + } + } + + @Test + public void test4p5AwsAutomaticWorkflowDoesNotUseCache() { + // #. Create a new client that uses the AWS automatic workflow. + // #. Ensure that a find operation does not add credentials to the cache. + setOidcFile("test_user1"); + MongoCredential credential = createOidcCredential(null) + .withMechanismProperty(PROVIDER_NAME_KEY, "aws"); + ConnectionString connectionString = new ConnectionString(AWS_OIDC_URL); + MongoClientSettings clientSettings = MongoClientSettings.builder() + .credential(credential) + .applyConnectionString(connectionString) + .build(); + try (MongoClient mongoClient = createMongoClient(clientSettings)) { + performFind(mongoClient); + // This ensures that the next find failure results in a file (rather than cache) read + failCommand(391, 1, "find"); + setOidcFile("invalid_file"); + assertCause(NoSuchFileException.class, "invalid_file", () -> performFind(mongoClient)); + } + } + + @Test + public void test5SpeculativeAuthentication() { + // #. We can only test the successful case, by verifying that saslStart is not called. + // #. Create a client with a request callback that returns a valid token that will not expire soon. + TestListener listener = new TestListener(); + TestCallback onRequest = createCallback().setEventListener(listener); + TestCommandListener commandListener = new TestCommandListener(); + commandListener.setEventStrings(listener); + MongoClientSettings clientSettings = createSettings(OIDC_URL, onRequest, null, null, commandListener); + try (MongoClient mongoClient = createMongoClient(clientSettings)) { + // instead of setting failpoints for saslStart, we inspect events + delayNextFind(); + executeAll(2, () -> performFind(mongoClient)); + + List events = listener.getEventStrings(); + assertFalse(events.stream().anyMatch(e -> e.contains("saslStart"))); + // onRequest is 2-step, so we expect 2 continues + assertEquals(2, events.stream().filter(e -> e.contains("saslContinue started")).count()); + // confirm all commands are enabled + assertTrue(events.stream().anyMatch(e -> e.contains("isMaster started"))); + } + } + + // Not a prose test + @Test + public void testAutomaticAuthUsesSpeculative() { + TestListener listener = new TestListener(); + TestCommandListener commandListener = new TestCommandListener(); + commandListener.setEventStrings(listener); + + MongoClientSettings settings = createSettings( + AWS_OIDC_URL, null, null, Arrays.asList(), commandListener); + try (MongoClient mongoClient = createMongoClient(settings)) { + // we use a listener instead of a failpoint + performFind(mongoClient); + assertEquals(Arrays.asList( + "isMaster started", + "isMaster succeeded", + "find started", + "find succeeded" + ), listener.getEventStrings()); + } + } + + @Test + public void test6p1ReauthenticationSucceeds() { + // #. Create request and refresh callbacks that return valid credentials that will not expire soon. + TestListener listener = new TestListener(); + TestCallback onRequest = createCallback().setEventListener(listener); + TestCallback onRefresh = createCallback().setEventListener(listener); + + // #. Create a client with the callbacks and an event listener capable of listening for SASL commands. + TestCommandListener commandListener = new TestCommandListener(); + commandListener.setEventStrings(listener); + + MongoClientSettings clientSettings = createSettings(OIDC_URL, onRequest, onRefresh, null, commandListener); + try (MongoClient mongoClient = createMongoClient(clientSettings)) { + + // #. Perform a find operation that succeeds. + performFind(mongoClient); + + // #. Assert that the refresh callback has not been called. + assertEquals(0, onRefresh.getInvocations()); + + assertEquals(Arrays.asList( + "isMaster started", + "isMaster succeeded", + "onRequest invoked", + "read access token: test_user1", + "saslContinue started", + "saslContinue succeeded", + "find started", + "find succeeded" + ), listener.getEventStrings()); + + // #. Clear the listener state if possible. + commandListener.reset(); + listener.clear(); + + // #. Force a reauthenication using a failCommand + failCommand(391, 1, "find"); + + // #. Perform another find operation that succeeds. + performFind(mongoClient); + + // #. Assert that the ordering of command started events is: find, find. + // #. Assert that the ordering of command succeeded events is: find. + // #. Assert that a find operation failed once during the command execution. + assertEquals(Arrays.asList( + "find started", + "find failed", + "onRefresh invoked", + "read access token: test_user1", + "saslStart started", + "saslStart succeeded", + "find started", + "find succeeded" + ), listener.getEventStrings()); + + // #. Assert that the refresh callback has been called once, if possible. + assertEquals(1, onRefresh.getInvocations()); + } + } + + //@Test // TODO-OIDC-x ignore this. not a prose test; will need to be updated after oidc spec changes + public void testFullReath() { + ConcurrentLinkedQueue tokens = tokenQueue( + "test_user1", // read during initial population + "test_user1_1", // read when first thread clears cache + "test_user1_expires", // read during onRefresh + "test_user1_2", // read during onRequest + "invalid"); + TestListener events = new TestListener() { + public void add(final String s) { + String message = new Date() + + " -- " + Thread.currentThread().getName() + + " -- " + s; + super.add(message); + } + }; + TestCallback onRequest = new TestCallback() + .setPathSupplier(() -> tokens.remove()) + .setEventListener(events); + TestCallback onRefresh = new TestCallback() + .setPathSupplier(() -> tokens.remove()) + .setEventListener(events); + TestCommandListener commandListener = new TestCommandListener(); + commandListener.setEventStrings(events); + + MongoClientSettings clientSettings = createSettings(OIDC_URL, onRequest, onRefresh, null, commandListener); + try (MongoClient mongoClient = createMongoClient(clientSettings)) { + // Populate the cache, authenticate both connections + delayNextFind(); + executeAll(2, () -> performFind(mongoClient)); + assertEquals(1, onRequest.getInvocations()); + assertEquals(0, onRefresh.getInvocations()); + + events.clear(); + + // Now we need a thread to arrive at AUTHLOCK after a failed find, + // but the cache must contain a new credential. + // The first thread performs a failing-find that takes a long time. + // Then, the second thread starts, and immediately fails its find, + // and passes through AUTHLOCK to populate the cache. + executeAll( + () -> { + failCommand(391, 1, "find"); + performFind(mongoClient); + }, + () -> { + sleep(500); // TODO-OIDC-x less time? + //events.clear(); + events.add("retrying task started"); + failCommand(391, 1, "find"); + performFind(mongoClient); + events.add("retrying task finished"); + }); + + //events.getEventStrings().forEach(e -> System.out.println(" \"" + e + "\",")); + assertEquals(Arrays.asList( + "retrying task started", + "find started", + "find failed", + // entered 391 retry logic + "onRefresh invoked", + "read access token: test_user1_expires", + "saslStart started", + "saslStart failed", + // + "saslStart started", + "saslStart succeeded", + "onRequest invoked", + "read access token: test_user1_2", + "saslContinue started", + "saslContinue succeeded", + "find started", + "find succeeded", + "retrying task finished" + ), events.getEventStrings()); + + } + } + + @NotNull + private ConcurrentLinkedQueue tokenQueue(final String... queue) { + return Stream + .of(queue) + .map(v -> TOKEN_DIRECTORY + v) + .collect(Collectors.toCollection(ConcurrentLinkedQueue::new)); + } + + @Test + public void test6p2ReauthenticationRetriesAndSucceedsWithCache() { + // #. Create request and refresh callbacks that return valid credentials that will not expire soon. + TestCallback onRequest = createCallback(); + TestCallback onRefresh = createCallback(); + MongoClientSettings clientSettings = createSettings(OIDC_URL, onRequest, onRefresh); + try (MongoClient mongoClient = createMongoClient(clientSettings)) { + // #. Perform a find operation that succeeds. + performFind(mongoClient); + // #. Force a reauthenication using a failCommand + failCommand(391, 2, "find", "saslStart"); + // #. Perform a find operation that succeeds. + performFind(mongoClient); + } + } + + // 6.3 Retries and Fails with no Cache + // TODO-OIDC-x appears to be untestable, since it requires 391 failure on jwt; awaiting spec changes +// @Test +// public void test6p3RetriesAndFailsWithNoCache() { +// fail(); +// } + + @Test + public void test6p4SeparateConnectionsAvoidExtraCallbackCalls() { + ConcurrentLinkedQueue tokens = tokenQueue( + "test_user1", + "test_user1_1"); + TestCallback onRequest = createCallback().setPathSupplier(() -> tokens.remove()); + TestCallback onRefresh = createCallback().setPathSupplier(() -> tokens.remove()); + MongoClientSettings clientSettings = createSettings(OIDC_URL, onRequest, onRefresh); + try (MongoClient mongoClient = createMongoClient(clientSettings)) { + // #. Peform a find operation on each ... that succeeds. + delayNextFind(); + executeAll(2, () -> performFind(mongoClient)); + // #. Ensure that the request callback has been called once and the refresh callback has not been called. + assertEquals(1, onRequest.getInvocations()); + assertEquals(0, onRefresh.getInvocations()); + + failCommand(391, 2, "find"); + executeAll(2, () -> performFind(mongoClient)); + + // #. Ensure that the request callback has been called once and the refresh callback has been called once. + assertEquals(1, onRequest.getInvocations()); + assertEquals(1, onRefresh.getInvocations()); + } + } + + public MongoClientSettings createSettings( + final String connectionString, + @Nullable final OidcRequestCallback onRequest, + @Nullable final OidcRefreshCallback onRefresh) { + return createSettings(connectionString, onRequest, onRefresh, null, null); + } + + private MongoClientSettings createSettings( + final String connectionString, + @Nullable final OidcRequestCallback onRequest, + @Nullable final OidcRefreshCallback onRefresh, + @Nullable final List allowedHosts, + @Nullable final CommandListener commandListener) { + ConnectionString cs = new ConnectionString(connectionString); + MongoCredential credential = cs.getCredential() + .withMechanismProperty(REQUEST_TOKEN_CALLBACK_KEY, onRequest) + .withMechanismProperty(REFRESH_TOKEN_CALLBACK_KEY, onRefresh) + .withMechanismProperty(ALLOWED_HOSTS_KEY, allowedHosts); + MongoClientSettings.Builder builder = MongoClientSettings.builder() + .applyConnectionString(cs) + .credential(credential); + if (commandListener != null) { + builder.addCommandListener(commandListener); + } + return builder.build(); + } + + private void performFind(final MongoClientSettings settings) { + try (MongoClient mongoClient = createMongoClient(settings)) { + performFind(mongoClient); + } + } + + private void performFind( + final MongoClientSettings settings, + final Class expectedExceptionOrCause, + final String expectedMessage) { + try (MongoClient mongoClient = createMongoClient(settings)) { + assertCause(expectedExceptionOrCause, expectedMessage, () -> performFind(mongoClient)); + } + } + + private void performFind(final MongoClient mongoClient) { + mongoClient + .getDatabase("test") + .getCollection("test") + .find() + .first(); + } + + private static void assertCause( + final Class expectedCause, final String expectedMessageFragment, final Executable e) { + Throwable actualException = assertThrows(Throwable.class, e); + assertCause(expectedCause, expectedMessageFragment, actualException); + } + + private static void assertCause( + final Class expectedCause, final String expectedMessageFragment, final Throwable actualException) { + Throwable cause = actualException; + while (cause.getCause() != null) { + cause = cause.getCause(); + } + if (!expectedCause.isInstance(cause)) { + throw new AssertionFailedError("Unexpected cause", actualException); + } + if (!cause.getMessage().contains(expectedMessageFragment)) { + throw new AssertionFailedError("Unexpected message", actualException); + } + } + + protected void delayNextFind() { + try (MongoClient client = createMongoClient(createSettings(AWS_OIDC_URL, null, null))) { + BsonDocument failPointDocument = new BsonDocument("configureFailPoint", new BsonString("failCommand")) + .append("mode", new BsonDocument("times", new BsonInt32(1))) + .append("data", new BsonDocument() + .append("failCommands", new BsonArray(asList(new BsonString("find")))) + .append("blockConnection", new BsonBoolean(true)) + .append("blockTimeMS", new BsonInt32(100))); + client.getDatabase("admin").runCommand(failPointDocument); + } + } + + protected void failCommand(final int code, final int times, final String... commands) { + try (MongoClient mongoClient = createMongoClient(createSettings( + AWS_OIDC_URL, null, null))) { + List list = Arrays.stream(commands).map(c -> new BsonString(c)).collect(Collectors.toList()); + BsonDocument failPointDocument = new BsonDocument("configureFailPoint", new BsonString("failCommand")) + .append("mode", new BsonDocument("times", new BsonInt32(times))) + .append("data", new BsonDocument() + .append("failCommands", new BsonArray(list)) + .append("errorCode", new BsonInt32(code))); + mongoClient.getDatabase("admin").runCommand(failPointDocument); + } + // TODO-OIDC-x the driver MUST either use a unique appName or explicitly remove the failCommand after the test to prevent leakage. + // .append("appName", new BsonString(appName)) + } + + public static class TestCallback implements OidcRequestCallback, OidcRefreshCallback { + private final AtomicInteger invocations = new AtomicInteger(); + @Nullable + private final Integer expiresInSeconds; + @Nullable + private final Integer delayInMilliseconds; + @Nullable + private final AtomicInteger concurrentTracker; + @Nullable + private final TestListener testListener; + @Nullable + private final Supplier pathSupplier; + + public TestCallback() { + this(60 * 60, null, new AtomicInteger(), null, null); + } + + public TestCallback( + @Nullable final Integer expiresInSeconds, + @Nullable final Integer delayInMilliseconds, + @Nullable final AtomicInteger concurrentTracker, + @Nullable final TestListener testListener, + @Nullable final Supplier pathSupplier) { + this.expiresInSeconds = expiresInSeconds; + this.delayInMilliseconds = delayInMilliseconds; + this.concurrentTracker = concurrentTracker; + this.testListener = testListener; + this.pathSupplier = pathSupplier; + } + + public int getInvocations() { + return invocations.get(); + } + + @Override + public IdpResponse onRequest(final OidcRequestContext context) { + if (testListener != null) { + testListener.add("onRequest invoked"); + } + return callback(); + } + + @Override + public IdpResponse onRefresh(final OidcRefreshContext context) { + if (context.getRefreshToken() == null) { + throw new IllegalArgumentException("refreshToken was null"); + } + if (testListener != null) { + testListener.add("onRefresh invoked"); + } + return callback(); + } + + @NotNull + private IdpResponse callback() { + if (concurrentTracker != null) { + if (concurrentTracker.get() > 0) { + throw new RuntimeException("Callbacks should not be invoked by multiple threads."); + } + concurrentTracker.incrementAndGet(); + } + try { + invocations.incrementAndGet(); + Path path = Paths.get(pathSupplier == null + ? getenv(AWS_WEB_IDENTITY_TOKEN_FILE) + : pathSupplier.get()); + String accessToken; + try { + simulateDelay(); + accessToken = new String(Files.readAllBytes(path), StandardCharsets.UTF_8); + } catch (IOException | InterruptedException e) { + throw new RuntimeException(e); + } + String refreshToken = "refreshToken"; + if (testListener != null) { + testListener.add("read access token: " + path.getFileName()); + } + return new IdpResponse( + accessToken, + expiresInSeconds, + refreshToken); + } finally { + if (concurrentTracker != null) { + concurrentTracker.decrementAndGet(); + } + } + } + + private void simulateDelay() throws InterruptedException { + if (delayInMilliseconds != null) { + Thread.sleep(delayInMilliseconds); + } + } + + public TestCallback setExpiresInSeconds(final Integer expiresInSeconds) { + return new TestCallback( + expiresInSeconds, + this.delayInMilliseconds, + this.concurrentTracker, + this.testListener, + this.pathSupplier); + } + + public TestCallback setDelayMs(final int milliseconds) { + return new TestCallback( + this.expiresInSeconds, + milliseconds, + this.concurrentTracker, + this.testListener, + this.pathSupplier); + } + + public TestCallback setConcurrentTracker(final AtomicInteger c) { + return new TestCallback( + this.expiresInSeconds, + this.delayInMilliseconds, + c, + this.testListener, + this.pathSupplier); + } + + public TestCallback setEventListener(final TestListener testListener) { + return new TestCallback( + this.expiresInSeconds, + this.delayInMilliseconds, + this.concurrentTracker, + testListener, + this.pathSupplier); + } + + public TestCallback setPathSupplier(final Supplier pathSupplier) { + return new TestCallback( + this.expiresInSeconds, + this.delayInMilliseconds, + this.concurrentTracker, + this.testListener, + pathSupplier); + } + + public TestCallback setExpired() { + return this.setExpiresInSeconds(60); + } + } + + public TestCallback createCallback() { + return new TestCallback(); + } +} From 6d4b3ae39bf76b246b896e75ebbf59474644621a Mon Sep 17 00:00:00 2001 From: Maxim Katcharov Date: Mon, 29 May 2023 11:47:45 -0600 Subject: [PATCH 04/19] PR fixes: naming, re-entrancy --- .../src/main/com/mongodb/MongoCredential.java | 15 ++--- .../connection/OidcAuthenticator.java | 58 ++++++++----------- 2 files changed, 33 insertions(+), 40 deletions(-) diff --git a/driver-core/src/main/com/mongodb/MongoCredential.java b/driver-core/src/main/com/mongodb/MongoCredential.java index ff3d4cc168d..a8ea5b71faa 100644 --- a/driver-core/src/main/com/mongodb/MongoCredential.java +++ b/driver-core/src/main/com/mongodb/MongoCredential.java @@ -717,21 +717,21 @@ public static final class IdpResponse { private final String accessToken; @Nullable - private final Integer expiresInSeconds; + private final Integer accessTokenExpiresInSeconds; @Nullable private final String refreshToken; /** * @param accessToken The OIDC access token - * @param expiresInSeconds The expiration in seconds. If null, the access token is single-use. + * @param accessTokenExpiresInSeconds The expiration in seconds. If null, the access token is single-use. * @param refreshToken The refresh token. If null, refresh will not be attempted. */ - public IdpResponse(final String accessToken, @Nullable final Integer expiresInSeconds, + public IdpResponse(final String accessToken, @Nullable final Integer accessTokenExpiresInSeconds, @Nullable final String refreshToken) { notNull("accessToken", accessToken); this.accessToken = accessToken; - this.expiresInSeconds = expiresInSeconds; + this.accessTokenExpiresInSeconds = accessTokenExpiresInSeconds; this.refreshToken = refreshToken; } @@ -743,11 +743,12 @@ public String getAccessToken() { } /** - * @return The expiration time in seconds. If null, the access token is single-use. + * @return The expiration time for the access token in seconds. + * If null, the access token is single-use. */ @Nullable - public Integer getExpiresInSeconds() { - return expiresInSeconds; + public Integer getAccessTokenExpiresInSeconds() { + return accessTokenExpiresInSeconds; } /** diff --git a/driver-core/src/main/com/mongodb/internal/connection/OidcAuthenticator.java b/driver-core/src/main/com/mongodb/internal/connection/OidcAuthenticator.java index 02b78237075..410876cde44 100644 --- a/driver-core/src/main/com/mongodb/internal/connection/OidcAuthenticator.java +++ b/driver-core/src/main/com/mongodb/internal/connection/OidcAuthenticator.java @@ -29,6 +29,7 @@ import com.mongodb.ServerApi; import com.mongodb.connection.ClusterConnectionMode; import com.mongodb.connection.ConnectionDescription; +import com.mongodb.internal.Timeout; import com.mongodb.lang.Nullable; import org.bson.BsonBinaryWriter; import org.bson.BsonDocument; @@ -45,12 +46,11 @@ import java.nio.file.Files; import java.nio.file.Paths; import java.time.Duration; -import java.time.Instant; -import java.time.temporal.ChronoUnit; import java.util.Arrays; import java.util.Collections; import java.util.List; import java.util.Map; +import java.util.concurrent.TimeUnit; import java.util.function.Function; import java.util.stream.Collectors; @@ -195,6 +195,8 @@ public void authenticate(final InternalConnection connection, final ConnectionDe } catch (MongoSecurityException e) { if (triggersRetry(e)) { // TODO-OIDC-x unclear how to provide test coverage for this authLock(connection, connectionDescription); + } else { + throw e; } } } else { @@ -314,31 +316,16 @@ private void prepareRetry(final MongoException e, final OidcCacheEntry cacheEntr @Nullable private String getValidCachedAccessToken() { - MongoCredentialWithCache mongoCredentialWithCache = getMongoCredentialWithCache(); - OidcCacheEntry cacheEntry = mongoCredentialWithCache.getOidcCacheEntry(); - String cachedAccessToken = cacheEntry.accessToken; - if (cachedAccessToken == null) { - return null; - } - if (cacheEntry.isExpired()) { - return mongoCredentialWithCache.withOidcLock(() -> { - OidcCacheEntry mostRecentCacheEntry = mongoCredentialWithCache.getOidcCacheEntry(); - if (mostRecentCacheEntry.isExpired()) { - mongoCredentialWithCache.setOidcCacheEntry(mostRecentCacheEntry.clearAccessToken()); - return null; - } else { - return mostRecentCacheEntry.accessToken; - } - }); - } - return cachedAccessToken; + return getMongoCredentialWithCache() + .getOidcCacheEntry() + .getValidCachedAccessToken(); } static final class OidcCacheEntry { @Nullable private final String accessToken; @Nullable - private final Instant expiry; + private final Timeout accessTokenExpiry; @Nullable private final String refreshToken; @Nullable @@ -348,22 +335,23 @@ static final class OidcCacheEntry { public String toString() { return "OidcCacheEntry{" + "\n accessToken#hashCode='" + (accessToken == null ? null : accessToken.hashCode()) + '\'' - + ",\n expiry=" + expiry + + ",\n expiry=" + accessTokenExpiry + ",\n refreshToken='" + refreshToken + '\'' + ",\n idpInfo=" + idpInfo + '}'; } OidcCacheEntry(@Nullable final IdpInfo idpInfo, final IdpResponse idpResponse) { - Integer expiresInSeconds = idpResponse.getExpiresInSeconds(); - if (expiresInSeconds != null) { - final Instant expiry = Instant.now().plusSeconds(expiresInSeconds) - .minus(5, ChronoUnit.MINUTES); + Integer accessTokenExpiresInSeconds = idpResponse.getAccessTokenExpiresInSeconds(); + if (accessTokenExpiresInSeconds != null) { this.accessToken = idpResponse.getAccessToken(); - this.expiry = expiry; + long accessTokenExpiryReservedSeconds = TimeUnit.MINUTES.toSeconds(5); + this.accessTokenExpiry = Timeout.startNow( + Math.max(0, accessTokenExpiresInSeconds - accessTokenExpiryReservedSeconds), + TimeUnit.SECONDS); } else { this.accessToken = null; - this.expiry = null; + this.accessTokenExpiry = null; } String refreshToken = idpResponse.getRefreshToken(); if (refreshToken != null) { @@ -379,16 +367,20 @@ public String toString() { this(null, null, null, null); } - private OidcCacheEntry(@Nullable final String accessToken, @Nullable final Instant expiry, + private OidcCacheEntry(@Nullable final String accessToken, @Nullable final Timeout accessTokenExpiry, @Nullable final String refreshToken, @Nullable final IdpInfo idpInfo) { this.accessToken = accessToken; - this.expiry = expiry; + this.accessTokenExpiry = accessTokenExpiry; this.refreshToken = refreshToken; this.idpInfo = idpInfo; } - public boolean isExpired() { - return expiry == null || Instant.now().isAfter(expiry); + @Nullable + public String getValidCachedAccessToken() { + if (accessToken == null || accessTokenExpiry == null || accessTokenExpiry.expired()) { + return null; + } + return accessToken; } public OidcCacheEntry clearAccessToken() { @@ -402,7 +394,7 @@ public OidcCacheEntry clearAccessToken() { public OidcCacheEntry clearRefreshToken() { return new OidcCacheEntry( this.accessToken, - this.expiry, + this.accessTokenExpiry, null, null); } From 30c43275ce1462ceed6e0b1638011d7051848e65 Mon Sep 17 00:00:00 2001 From: Maxim Katcharov Date: Tue, 30 May 2023 07:35:12 -0600 Subject: [PATCH 05/19] PR fixes --- .../src/main/com/mongodb/MongoCredential.java | 12 ++- .../src/main/com/mongodb/internal/Locks.java | 10 ++ .../connection/MongoCredentialWithCache.java | 10 +- .../connection/OidcAuthenticator.java | 99 ++++++++++++------- .../connection/SaslAuthenticator.java | 6 +- 5 files changed, 88 insertions(+), 49 deletions(-) diff --git a/driver-core/src/main/com/mongodb/MongoCredential.java b/driver-core/src/main/com/mongodb/MongoCredential.java index a8ea5b71faa..46a81d3fcc5 100644 --- a/driver-core/src/main/com/mongodb/MongoCredential.java +++ b/driver-core/src/main/com/mongodb/MongoCredential.java @@ -17,6 +17,7 @@ package com.mongodb; import com.mongodb.annotations.Beta; +import com.mongodb.annotations.Evolving; import com.mongodb.annotations.Immutable; import com.mongodb.lang.Nullable; @@ -186,8 +187,8 @@ public final class MongoCredential { /** * The provider name. The value must be a string. *

- * If this is provided, neither - * {@link MongoCredential#REQUEST_TOKEN_CALLBACK_KEY} nor + * If this is provided, + * {@link MongoCredential#REQUEST_TOKEN_CALLBACK_KEY} and * {@link MongoCredential#REFRESH_TOKEN_CALLBACK_KEY} * must not be provided. * @@ -224,7 +225,7 @@ public final class MongoCredential { public static final String REFRESH_TOKEN_CALLBACK_KEY = "REFRESH_TOKEN_CALLBACK"; /** - * Mechanism key for a list of allowed hostnames or ip-addresses (ignoring ports) for MongoDB connections. + * Mechanism key for a list of allowed hostnames or ip-addresses for MongoDB connections. Ports must be excluded. * The hostnames may include a leading "*." wildcard, which allows for matching (potentially nested) subdomains. * When MONGODB-OIDC authentication is attempted against a hostname that does not match any of list of allowed hosts * the driver will raise an error. The type of the value must be {@code List}. @@ -238,6 +239,8 @@ public final class MongoCredential { /** * The list of allowed hosts that will be used if no * {@link MongoCredential#ALLOWED_HOSTS_KEY} value is supplied. + * The default allowed hosts are: + * {@code "*.mongodb.net", "*.mongodb-dev.net", "*.mongodbgov.net", "localhost", "127.0.0.1", "::1")} * * @see #createOidcCredential(String) * @since 4.10 @@ -634,6 +637,7 @@ public String toString() { /** * The context for the {@link OidcRequestCallback#onRequest(OidcRequestContext) OIDC request callback}. */ + @Evolving public interface OidcRequestContext { /** * @return The OIDC Identity Provider's configuration that can be used to acquire an Access Token. @@ -649,6 +653,7 @@ public interface OidcRequestContext { /** * The context for the {@link OidcRefreshCallback#onRefresh(OidcRefreshContext) OIDC refresh callback}. */ + @Evolving public interface OidcRefreshContext extends OidcRequestContext { /** * @return The OIDC Refresh token supplied by a prior callback invocation. @@ -690,6 +695,7 @@ public interface OidcRefreshCallback { /** * The OIDC Identity Provider's configuration that can be used to acquire an Access Token. */ + @Evolving public interface IdpInfo { /** * @return URL which describes the Authorization Server. This identifier is the diff --git a/driver-core/src/main/com/mongodb/internal/Locks.java b/driver-core/src/main/com/mongodb/internal/Locks.java index 25fbe199966..6b619b5f34d 100644 --- a/driver-core/src/main/com/mongodb/internal/Locks.java +++ b/driver-core/src/main/com/mongodb/internal/Locks.java @@ -19,6 +19,7 @@ import com.mongodb.MongoInterruptedException; import java.util.concurrent.locks.Lock; +import java.util.concurrent.locks.StampedLock; import java.util.function.Supplier; /** @@ -32,6 +33,15 @@ public static void withLock(final Lock lock, final Runnable action) { }); } + public static V withLock(final StampedLock lock, final Supplier supplier) { + long stamp = lock.writeLock(); + try { + return supplier.get(); + } finally { + lock.unlockWrite(stamp); + } + } + public static V withLock(final Lock lock, final Supplier supplier) { return checkedWithLock(lock, supplier::get); } diff --git a/driver-core/src/main/com/mongodb/internal/connection/MongoCredentialWithCache.java b/driver-core/src/main/com/mongodb/internal/connection/MongoCredentialWithCache.java index 50579c2b28d..d6776018e18 100644 --- a/driver-core/src/main/com/mongodb/internal/connection/MongoCredentialWithCache.java +++ b/driver-core/src/main/com/mongodb/internal/connection/MongoCredentialWithCache.java @@ -18,15 +18,13 @@ import com.mongodb.AuthenticationMechanism; import com.mongodb.MongoCredential; -import com.mongodb.internal.Locks; import com.mongodb.lang.Nullable; import java.util.concurrent.locks.Lock; import java.util.concurrent.locks.ReentrantLock; -import java.util.function.Supplier; +import java.util.concurrent.locks.StampedLock; import static com.mongodb.internal.Locks.withLock; - import static com.mongodb.internal.connection.OidcAuthenticator.OidcCacheEntry; /** @@ -76,8 +74,8 @@ public void setOidcCacheEntry(final OidcCacheEntry oidcCacheEntry) { this.cache.oidcCacheEntry = oidcCacheEntry; } - public V withOidcLock(final Supplier k) { - return Locks.withLock(cache.oidcLock, k); + public StampedLock getOidcLock() { + return cache.oidcLock; } public Lock getLock() { @@ -93,7 +91,7 @@ static class Cache { private Object cacheValue; - private final ReentrantLock oidcLock = new ReentrantLock(); + private final StampedLock oidcLock = new StampedLock(); private volatile OidcCacheEntry oidcCacheEntry = new OidcCacheEntry(); Object get(final Object key) { diff --git a/driver-core/src/main/com/mongodb/internal/connection/OidcAuthenticator.java b/driver-core/src/main/com/mongodb/internal/connection/OidcAuthenticator.java index 410876cde44..d8fbeb80fae 100644 --- a/driver-core/src/main/com/mongodb/internal/connection/OidcAuthenticator.java +++ b/driver-core/src/main/com/mongodb/internal/connection/OidcAuthenticator.java @@ -29,6 +29,7 @@ import com.mongodb.ServerApi; import com.mongodb.connection.ClusterConnectionMode; import com.mongodb.connection.ConnectionDescription; +import com.mongodb.internal.Locks; import com.mongodb.internal.Timeout; import com.mongodb.lang.Nullable; import org.bson.BsonBinaryWriter; @@ -50,6 +51,7 @@ import java.util.Collections; import java.util.List; import java.util.Map; +import java.util.Objects; import java.util.concurrent.TimeUnit; import java.util.function.Function; import java.util.stream.Collectors; @@ -66,6 +68,7 @@ import static com.mongodb.MongoCredential.OidcRequestCallback; import static com.mongodb.assertions.Assertions.assertFalse; import static com.mongodb.assertions.Assertions.assertNotNull; +import static com.mongodb.assertions.Assertions.assertTrue; import static com.mongodb.assertions.Assertions.notNull; import static java.lang.String.format; @@ -115,17 +118,14 @@ protected SaslClient createSaslClient(final ServerAddress serverAddress) { @Nullable public BsonDocument createSpeculativeAuthenticateCommand(final InternalConnection connection) { try { - OidcRequestCallback requestCallback = getRequestCallback(); - if (requestCallback == null) { + if (isAutomaticAuthentication()) { return wrapInSpeculative(prepareAwsTokenFromFile()); } String cachedAccessToken = getValidCachedAccessToken(); MongoCredentialWithCache mongoCredentialWithCache = getMongoCredentialWithCache(); if (cachedAccessToken != null) { - connectionLastAccessToken = cachedAccessToken; - fallbackState = FallbackState.PHASE_1_CACHED_TOKEN; return wrapInSpeculative(prepareTokenAsJwt(cachedAccessToken)); - } else if (mongoCredentialWithCache.getOidcCacheEntry().idpInfo == null) { + } else if (mongoCredentialWithCache.getOidcCacheEntry().getIdpInfo() == null) { String userName = mongoCredentialWithCache.getCredential().getUserName(); return wrapInSpeculative(prepareUsername(userName)); } else { @@ -179,13 +179,15 @@ private OidcRequestCallback getRequestCallback() { @Override public void reauthenticate(final InternalConnection connection) { + // method must only be called after original handshake: + assertTrue(connection.opened()); fallbackState = FallbackState.INITIAL; authLock(connection, connection.getDescription()); } @Override public void authenticate(final InternalConnection connection, final ConnectionDescription connectionDescription) { - // method must only be called during original handshake; fail otherwise + // method must only be called during original handshake: assertFalse(connection.opened()); // this method "wraps" the default authentication method in custom OIDC retry logic String accessToken = getValidCachedAccessToken(); @@ -216,7 +218,6 @@ private static boolean triggersRetry(@Nullable final Throwable t) { return false; } - private void authenticateUsing( final InternalConnection connection, final ConnectionDescription connectionDescription, @@ -226,17 +227,14 @@ private void authenticateUsing( } private void authLock(final InternalConnection connection, final ConnectionDescription connectionDescription) { - MongoCredentialWithCache mongoCredentialWithCache = getMongoCredentialWithCache(); - mongoCredentialWithCache.withOidcLock(() -> { + Locks.withLock(getMongoCredentialWithCache().getOidcLock(), () -> { while (true) { try { authenticateUsing(connection, connectionDescription, (challenge) -> evaluate(challenge)); break; } catch (MongoSecurityException e) { - OidcCacheEntry cacheEntry = mongoCredentialWithCache.getOidcCacheEntry(); - if (triggersRetry(e)) { - prepareRetry(e, cacheEntry); - } else { + boolean shouldRetry = triggersRetry(e) && shouldRetryHandler(); + if (!shouldRetry) { throw e; } } @@ -246,17 +244,17 @@ private void authLock(final InternalConnection connection, final ConnectionDescr } private byte[] evaluate(final byte[] challenge) { - OidcRequestCallback requestCallback = getRequestCallback(); - if (requestCallback == null) { + if (isAutomaticAuthentication()) { return prepareAwsTokenFromFile(); } + OidcRequestCallback requestCallback = assertNotNull(getRequestCallback()); MongoCredentialWithCache mongoCredentialWithCache = getMongoCredentialWithCache(); OidcCacheEntry cacheEntry = mongoCredentialWithCache.getOidcCacheEntry(); String cachedAccessToken = getValidCachedAccessToken(); String invalidConnectionAccessToken = connectionLastAccessToken; - String cachedRefreshToken = cacheEntry.refreshToken; - IdpInfo cachedIdpInfo = cacheEntry.idpInfo; + String cachedRefreshToken = cacheEntry.getRefreshToken(); + IdpInfo cachedIdpInfo = cacheEntry.getIdpInfo(); if (cachedAccessToken != null) { boolean cachedTokenIsInvalid = cachedAccessToken.equals(invalidConnectionAccessToken); @@ -269,32 +267,39 @@ private byte[] evaluate(final byte[] challenge) { if (cachedAccessToken != null) { fallbackState = FallbackState.PHASE_1_CACHED_TOKEN; return prepareTokenAsJwt(cachedAccessToken); - } else if (refreshCallback != null && cachedRefreshToken != null && cachedIdpInfo != null) { - fallbackState = FallbackState.PHASE_2_REFRESH_CALLBACK_TOKEN; + } else if (refreshCallback != null && cachedRefreshToken != null) { + assertTrue(cachedIdpInfo != null); // Invoke Refresh Callback using cached Refresh Token validateAllowedHosts(getMongoCredential()); + fallbackState = FallbackState.PHASE_2_REFRESH_CALLBACK_TOKEN; IdpResponse result = refreshCallback.onRefresh(new OidcRefreshContextImpl( cachedIdpInfo, cachedRefreshToken, CALLBACK_TIMEOUT)); - return handleCallbackResult(cachedIdpInfo, result); + return populateCacheWithCallbackResultAndPrepareJwt(cachedIdpInfo, result); } else { // cache is empty - if (fallbackState != FallbackState.PHASE_3A_PRINCIPAL && challenge.length == 0) { + boolean idpInfoNotPresent = challenge.length == 0; + if (fallbackState != FallbackState.PHASE_3A_PRINCIPAL && idpInfoNotPresent) { fallbackState = FallbackState.PHASE_3A_PRINCIPAL; return prepareUsername(mongoCredentialWithCache.getCredential().getUserName()); } else { - fallbackState = FallbackState.PHASE_3B_REQUEST_CALLBACK_TOKEN; IdpInfo idpInfo = toIdpInfo(challenge); IdpResponse result = invokeRequestCallback(requestCallback, idpInfo); - return handleCallbackResult(idpInfo, result); + fallbackState = FallbackState.PHASE_3B_REQUEST_CALLBACK_TOKEN; + return populateCacheWithCallbackResultAndPrepareJwt(idpInfo, result); } } } + private boolean isAutomaticAuthentication() { + return getRequestCallback() == null; + } + private boolean clientIsComplete() { return fallbackState != FallbackState.PHASE_3A_PRINCIPAL; } - private void prepareRetry(final MongoException e, final OidcCacheEntry cacheEntry) { + private boolean shouldRetryHandler() { + OidcCacheEntry cacheEntry = getMongoCredentialWithCache().getOidcCacheEntry(); MongoCredentialWithCache mongoCredentialWithCache = getMongoCredentialWithCache(); if (fallbackState == FallbackState.PHASE_1_CACHED_TOKEN) { // a cached access token failed @@ -310,8 +315,9 @@ private void prepareRetry(final MongoException e, final OidcCacheEntry cacheEntr mongoCredentialWithCache.setOidcCacheEntry(cacheEntry .clearAccessToken() .clearRefreshToken()); - throw e; + return false; } + return true; } @Nullable @@ -334,14 +340,14 @@ static final class OidcCacheEntry { @Override public String toString() { return "OidcCacheEntry{" - + "\n accessToken#hashCode='" + (accessToken == null ? null : accessToken.hashCode()) + '\'' - + ",\n expiry=" + accessTokenExpiry + + "\n accessToken#hashCode='" + Objects.hashCode(accessToken) + '\'' + + ",\n accessTokenExpiry=" + accessTokenExpiry + ",\n refreshToken='" + refreshToken + '\'' + ",\n idpInfo=" + idpInfo + '}'; } - OidcCacheEntry(@Nullable final IdpInfo idpInfo, final IdpResponse idpResponse) { + OidcCacheEntry(final IdpInfo idpInfo, final IdpResponse idpResponse) { Integer accessTokenExpiresInSeconds = idpResponse.getAccessTokenExpiresInSeconds(); if (accessTokenExpiresInSeconds != null) { this.accessToken = idpResponse.getAccessToken(); @@ -376,14 +382,24 @@ private OidcCacheEntry(@Nullable final String accessToken, @Nullable final Timeo } @Nullable - public String getValidCachedAccessToken() { + String getValidCachedAccessToken() { if (accessToken == null || accessTokenExpiry == null || accessTokenExpiry.expired()) { return null; } return accessToken; } - public OidcCacheEntry clearAccessToken() { + @Nullable + public String getRefreshToken() { + return refreshToken; + } + + @Nullable + public IdpInfo getIdpInfo() { + return idpInfo; + } + + OidcCacheEntry clearAccessToken() { return new OidcCacheEntry( null, null, @@ -391,7 +407,7 @@ public OidcCacheEntry clearAccessToken() { this.idpInfo); } - public OidcCacheEntry clearRefreshToken() { + OidcCacheEntry clearRefreshToken() { return new OidcCacheEntry( this.accessToken, this.accessTokenExpiry, @@ -449,15 +465,15 @@ private byte[] prepareUsername(@Nullable final String username) { return toBson(document); } - private byte[] handleCallbackResult( + private byte[] populateCacheWithCallbackResultAndPrepareJwt( final IdpInfo serverInfo, - @Nullable final IdpResponse tokens) { - if (tokens == null) { + @Nullable final IdpResponse idpResponse) { + if (idpResponse == null) { throw new MongoConfigurationException("Result of callback must not be null"); } - OidcCacheEntry newEntry = new OidcCacheEntry(serverInfo, tokens); + OidcCacheEntry newEntry = new OidcCacheEntry(serverInfo, idpResponse); getMongoCredentialWithCache().setOidcCacheEntry(newEntry); - return prepareTokenAsJwt(tokens.getAccessToken()); + return prepareTokenAsJwt(idpResponse.getAccessToken()); } private IdpInfo toIdpInfo(final byte[] challenge) { @@ -522,6 +538,9 @@ private static byte[] toBson(final BsonDocument document) { return bytes; } + /** + * Contains all validation logic for OIDC in one location + */ public static final class OidcValidator { private OidcValidator() { } @@ -591,10 +610,12 @@ public OidcRequestContextImpl(final IdpInfo idpInfo, final Duration timeout) { this.timeout = timeout; } + @Override public IdpInfo getIdpInfo() { return idpInfo; } + @Override public Duration getTimeout() { return timeout; } @@ -611,6 +632,7 @@ public OidcRefreshContextImpl(final IdpInfo idpInfo, final String refreshToken, this.refreshToken = refreshToken; } + @Override public String getRefreshToken() { return refreshToken; } @@ -630,14 +652,17 @@ public IdpInfoImpl(final String issuer, final String clientId, @Nullable final L : Collections.unmodifiableList(requestScopes); } + @Override public String getIssuer() { return issuer; } + @Override public String getClientId() { return clientId; } + @Override public List getRequestScopes() { return requestScopes; } @@ -646,7 +671,7 @@ public List getRequestScopes() { /** * Represents what was sent in the last request to the MongoDB server. */ - enum FallbackState { + private enum FallbackState { INITIAL, PHASE_1_CACHED_TOKEN, PHASE_2_REFRESH_CALLBACK_TOKEN, diff --git a/driver-core/src/main/com/mongodb/internal/connection/SaslAuthenticator.java b/driver-core/src/main/com/mongodb/internal/connection/SaslAuthenticator.java index d1298be571e..296cb7540e4 100644 --- a/driver-core/src/main/com/mongodb/internal/connection/SaslAuthenticator.java +++ b/driver-core/src/main/com/mongodb/internal/connection/SaslAuthenticator.java @@ -141,9 +141,9 @@ private BsonDocument getNextSaslResponse(final SaslClient saslClient, final Inte private void getNextSaslResponseAsync(final SaslClient saslClient, final InternalConnection connection, final SingleResultCallback callback) { + SingleResultCallback errHandlingCallback = errorHandlingCallback(callback, LOGGER); try { BsonDocument response = getSpeculativeAuthenticateResponse(); - SingleResultCallback errHandlingCallback = errorHandlingCallback(callback, LOGGER); if (response == null) { byte[] serverResponse = (saslClient.hasInitialResponse() ? saslClient.evaluateChallenge(new byte[0]) : null); sendSaslStartAsync(serverResponse, connection, (result, t) -> { @@ -341,7 +341,7 @@ private void continueConversation(final BsonDocument result) { protected abstract static class SaslClientImpl implements SaslClient { private final MongoCredential credential; - public SaslClientImpl(final MongoCredential credential) { + protected SaslClientImpl(final MongoCredential credential) { this.credential = credential; } @@ -379,7 +379,7 @@ public final String getMechanismName() { return authMechanism.getMechanismName(); } - protected MongoCredential getCredential() { + protected final MongoCredential getCredential() { return credential; } } From f40e66f9eb40fda9ed90fc1dc198255d858b8081 Mon Sep 17 00:00:00 2001 From: Maxim Katcharov Date: Tue, 30 May 2023 08:57:05 -0600 Subject: [PATCH 06/19] PR fixes, remove TODOs --- .../src/main/com/mongodb/internal/Locks.java | 11 +- .../connection/OidcAuthenticator.java | 2 +- .../client/OidcAuthenticationProseTests.java | 112 +++--------------- 3 files changed, 30 insertions(+), 95 deletions(-) diff --git a/driver-core/src/main/com/mongodb/internal/Locks.java b/driver-core/src/main/com/mongodb/internal/Locks.java index 6b619b5f34d..0c48ae4598c 100644 --- a/driver-core/src/main/com/mongodb/internal/Locks.java +++ b/driver-core/src/main/com/mongodb/internal/Locks.java @@ -22,6 +22,8 @@ import java.util.concurrent.locks.StampedLock; import java.util.function.Supplier; +import static com.mongodb.assertions.Assertions.assertFalse; + /** *

This class is not part of the public API and may be removed or changed at any time

*/ @@ -34,7 +36,14 @@ public static void withLock(final Lock lock, final Runnable action) { } public static V withLock(final StampedLock lock, final Supplier supplier) { - long stamp = lock.writeLock(); + long stamp; + try { + assertFalse(lock.isWriteLocked()); // not re-entrant, prevent deadlock + stamp = lock.writeLockInterruptibly(); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + throw new MongoInterruptedException("Interrupted waiting for lock", e); + } try { return supplier.get(); } finally { diff --git a/driver-core/src/main/com/mongodb/internal/connection/OidcAuthenticator.java b/driver-core/src/main/com/mongodb/internal/connection/OidcAuthenticator.java index d8fbeb80fae..8d91c914c73 100644 --- a/driver-core/src/main/com/mongodb/internal/connection/OidcAuthenticator.java +++ b/driver-core/src/main/com/mongodb/internal/connection/OidcAuthenticator.java @@ -195,7 +195,7 @@ public void authenticate(final InternalConnection connection, final ConnectionDe try { authenticateUsing(connection, connectionDescription, (bytes) -> prepareTokenAsJwt(accessToken)); } catch (MongoSecurityException e) { - if (triggersRetry(e)) { // TODO-OIDC-x unclear how to provide test coverage for this + if (triggersRetry(e)) { authLock(connection, connectionDescription); } else { throw e; diff --git a/driver-sync/src/test/functional/com/mongodb/client/OidcAuthenticationProseTests.java b/driver-sync/src/test/functional/com/mongodb/client/OidcAuthenticationProseTests.java index 2bb6fc794f8..a6dcc24d92b 100644 --- a/driver-sync/src/test/functional/com/mongodb/client/OidcAuthenticationProseTests.java +++ b/driver-sync/src/test/functional/com/mongodb/client/OidcAuthenticationProseTests.java @@ -52,8 +52,8 @@ import java.nio.file.Paths; import java.time.Duration; import java.util.Arrays; -import java.util.Date; import java.util.List; +import java.util.Random; import java.util.concurrent.ConcurrentLinkedQueue; import java.util.concurrent.atomic.AtomicInteger; import java.util.function.Supplier; @@ -63,11 +63,11 @@ import static com.mongodb.MongoCredential.ALLOWED_HOSTS_KEY; import static com.mongodb.MongoCredential.IdpInfo; import static com.mongodb.MongoCredential.OidcRefreshContext; +import static com.mongodb.MongoCredential.OidcRequestCallback; import static com.mongodb.MongoCredential.OidcRequestContext; import static com.mongodb.MongoCredential.PROVIDER_NAME_KEY; import static com.mongodb.MongoCredential.REFRESH_TOKEN_CALLBACK_KEY; import static com.mongodb.MongoCredential.REQUEST_TOKEN_CALLBACK_KEY; -import static com.mongodb.MongoCredential.OidcRequestCallback; import static com.mongodb.MongoCredential.createOidcCredential; import static com.mongodb.client.TestHelper.setEnvironmentVariable; import static java.lang.System.getenv; @@ -77,10 +77,15 @@ import static org.junit.jupiter.api.Assertions.assertThrows; import static org.junit.jupiter.api.Assertions.assertTrue; import static org.junit.jupiter.api.Assertions.fail; +import static org.junit.jupiter.api.Assumptions.assumeTrue; import static util.ThreadTestHelpers.executeAll; public class OidcAuthenticationProseTests { + public static boolean oidcTestsEnabled() { + return "true".equals(getenv().get("OIDC_TESTS_ENABLED")); + } + private static final String AWS_WEB_IDENTITY_TOKEN_FILE = "AWS_WEB_IDENTITY_TOKEN_FILE"; public static final String TOKEN_DIRECTORY = "/tmp/tokens/"; // TODO-OIDC @@ -88,6 +93,7 @@ public class OidcAuthenticationProseTests { protected static final String OIDC_URL = "mongodb://localhost/?authMechanism=MONGODB-OIDC"; private static final String AWS_OIDC_URL = "mongodb://localhost/?authMechanism=MONGODB-OIDC&authMechanismProperties=PROVIDER_NAME:aws"; + private String appName; protected MongoClient createMongoClient(final MongoClientSettings settings) { return MongoClients.create(settings); @@ -99,9 +105,11 @@ protected void setOidcFile(final String file) { @BeforeEach public void beforeEach() { + assumeTrue(oidcTestsEnabled()); + // In each test, clearing the cache is not required, since there is no global cache setOidcFile("test_user1"); InternalStreamConnection.setRecordEverything(true); - // In each test, clearing the cache is not required, since there is no global cache + this.appName = this.getClass().getSimpleName() + "-" + new Random().nextInt(Integer.MAX_VALUE); } @AfterEach @@ -136,8 +144,6 @@ public void test1CallbackDrivenAuth(final String name, final String file, final } } - // TODO-OIDC-x additional tests for token with null expiry? - @ParameterizedTest @CsvSource(delimiter = '#', value = { // 1.6, both variants: @@ -208,6 +214,7 @@ public void test2AwsAutomaticAuth(final String name, final String file, final St MongoCredential credential = createOidcCredential(null) .withMechanismProperty(PROVIDER_NAME_KEY, "aws"); MongoClientSettings clientSettings = MongoClientSettings.builder() + .applicationName(appName) .credential(credential) .applyConnectionString(new ConnectionString(url)) .build(); @@ -238,14 +245,14 @@ public void test3p1ValidCallbacks() { OidcRequestCallback onRequest2 = (context) -> { assertEquals(expectedClientId, context.getIdpInfo().getClientId()); assertEquals(expectedIssuer, context.getIdpInfo().getIssuer()); - //assertEquals(Arrays.asList(""), serverInfo.getRequestScopes()); // TODO-OIDC-x fix when docker updated + assertEquals(Arrays.asList(), context.getIdpInfo().getRequestScopes()); assertEquals(expectedSeconds, context.getTimeout()); return onRequest.onRequest(context); }; OidcRefreshCallback onRefresh2 = (context) -> { assertEquals(expectedClientId, context.getIdpInfo().getClientId()); assertEquals(expectedIssuer, context.getIdpInfo().getIssuer()); - //assertEquals(Arrays.asList(""), serverInfo.getRequestScopes()); // TODO-OIDC-x fix when docker updated + assertEquals(Arrays.asList(), context.getIdpInfo().getRequestScopes()); assertEquals(expectedSeconds, context.getTimeout()); assertEquals("refreshToken", context.getRefreshToken()); return onRefresh.onRefresh(context); @@ -513,6 +520,7 @@ public void test4p5AwsAutomaticWorkflowDoesNotUseCache() { .withMechanismProperty(PROVIDER_NAME_KEY, "aws"); ConnectionString connectionString = new ConnectionString(AWS_OIDC_URL); MongoClientSettings clientSettings = MongoClientSettings.builder() + .applicationName(appName) .credential(credential) .applyConnectionString(connectionString) .build(); @@ -629,85 +637,6 @@ public void test6p1ReauthenticationSucceeds() { } } - //@Test // TODO-OIDC-x ignore this. not a prose test; will need to be updated after oidc spec changes - public void testFullReath() { - ConcurrentLinkedQueue tokens = tokenQueue( - "test_user1", // read during initial population - "test_user1_1", // read when first thread clears cache - "test_user1_expires", // read during onRefresh - "test_user1_2", // read during onRequest - "invalid"); - TestListener events = new TestListener() { - public void add(final String s) { - String message = new Date() - + " -- " + Thread.currentThread().getName() - + " -- " + s; - super.add(message); - } - }; - TestCallback onRequest = new TestCallback() - .setPathSupplier(() -> tokens.remove()) - .setEventListener(events); - TestCallback onRefresh = new TestCallback() - .setPathSupplier(() -> tokens.remove()) - .setEventListener(events); - TestCommandListener commandListener = new TestCommandListener(); - commandListener.setEventStrings(events); - - MongoClientSettings clientSettings = createSettings(OIDC_URL, onRequest, onRefresh, null, commandListener); - try (MongoClient mongoClient = createMongoClient(clientSettings)) { - // Populate the cache, authenticate both connections - delayNextFind(); - executeAll(2, () -> performFind(mongoClient)); - assertEquals(1, onRequest.getInvocations()); - assertEquals(0, onRefresh.getInvocations()); - - events.clear(); - - // Now we need a thread to arrive at AUTHLOCK after a failed find, - // but the cache must contain a new credential. - // The first thread performs a failing-find that takes a long time. - // Then, the second thread starts, and immediately fails its find, - // and passes through AUTHLOCK to populate the cache. - executeAll( - () -> { - failCommand(391, 1, "find"); - performFind(mongoClient); - }, - () -> { - sleep(500); // TODO-OIDC-x less time? - //events.clear(); - events.add("retrying task started"); - failCommand(391, 1, "find"); - performFind(mongoClient); - events.add("retrying task finished"); - }); - - //events.getEventStrings().forEach(e -> System.out.println(" \"" + e + "\",")); - assertEquals(Arrays.asList( - "retrying task started", - "find started", - "find failed", - // entered 391 retry logic - "onRefresh invoked", - "read access token: test_user1_expires", - "saslStart started", - "saslStart failed", - // - "saslStart started", - "saslStart succeeded", - "onRequest invoked", - "read access token: test_user1_2", - "saslContinue started", - "saslContinue succeeded", - "find started", - "find succeeded", - "retrying task finished" - ), events.getEventStrings()); - - } - } - @NotNull private ConcurrentLinkedQueue tokenQueue(final String... queue) { return Stream @@ -733,11 +662,7 @@ public void test6p2ReauthenticationRetriesAndSucceedsWithCache() { } // 6.3 Retries and Fails with no Cache - // TODO-OIDC-x appears to be untestable, since it requires 391 failure on jwt; awaiting spec changes -// @Test -// public void test6p3RetriesAndFailsWithNoCache() { -// fail(); -// } + // Appears to be untestable, since it requires 391 failure on jwt (may be fixed in future spec) @Test public void test6p4SeparateConnectionsAvoidExtraCallbackCalls() { @@ -783,6 +708,7 @@ private MongoClientSettings createSettings( .withMechanismProperty(REFRESH_TOKEN_CALLBACK_KEY, onRefresh) .withMechanismProperty(ALLOWED_HOSTS_KEY, allowedHosts); MongoClientSettings.Builder builder = MongoClientSettings.builder() + .applicationName(appName) .applyConnectionString(cs) .credential(credential); if (commandListener != null) { @@ -839,6 +765,7 @@ protected void delayNextFind() { BsonDocument failPointDocument = new BsonDocument("configureFailPoint", new BsonString("failCommand")) .append("mode", new BsonDocument("times", new BsonInt32(1))) .append("data", new BsonDocument() + .append("appName", new BsonString(appName)) .append("failCommands", new BsonArray(asList(new BsonString("find")))) .append("blockConnection", new BsonBoolean(true)) .append("blockTimeMS", new BsonInt32(100))); @@ -853,12 +780,11 @@ protected void failCommand(final int code, final int times, final String... comm BsonDocument failPointDocument = new BsonDocument("configureFailPoint", new BsonString("failCommand")) .append("mode", new BsonDocument("times", new BsonInt32(times))) .append("data", new BsonDocument() + .append("appName", new BsonString(appName)) .append("failCommands", new BsonArray(list)) .append("errorCode", new BsonInt32(code))); mongoClient.getDatabase("admin").runCommand(failPointDocument); } - // TODO-OIDC-x the driver MUST either use a unique appName or explicitly remove the failCommand after the test to prevent leakage. - // .append("appName", new BsonString(appName)) } public static class TestCallback implements OidcRequestCallback, OidcRefreshCallback { From 02eeceff5fc0603dbf6cf6c01185e151ef5400b6 Mon Sep 17 00:00:00 2001 From: Maxim Katcharov Date: Tue, 30 May 2023 10:21:27 -0600 Subject: [PATCH 07/19] PR fix: cannot check re-entry --- driver-core/src/main/com/mongodb/internal/Locks.java | 3 --- .../com/mongodb/client/OidcAuthenticationProseTests.java | 2 +- 2 files changed, 1 insertion(+), 4 deletions(-) diff --git a/driver-core/src/main/com/mongodb/internal/Locks.java b/driver-core/src/main/com/mongodb/internal/Locks.java index 0c48ae4598c..c06eddcc6dd 100644 --- a/driver-core/src/main/com/mongodb/internal/Locks.java +++ b/driver-core/src/main/com/mongodb/internal/Locks.java @@ -22,8 +22,6 @@ import java.util.concurrent.locks.StampedLock; import java.util.function.Supplier; -import static com.mongodb.assertions.Assertions.assertFalse; - /** *

This class is not part of the public API and may be removed or changed at any time

*/ @@ -38,7 +36,6 @@ public static void withLock(final Lock lock, final Runnable action) { public static V withLock(final StampedLock lock, final Supplier supplier) { long stamp; try { - assertFalse(lock.isWriteLocked()); // not re-entrant, prevent deadlock stamp = lock.writeLockInterruptibly(); } catch (InterruptedException e) { Thread.currentThread().interrupt(); diff --git a/driver-sync/src/test/functional/com/mongodb/client/OidcAuthenticationProseTests.java b/driver-sync/src/test/functional/com/mongodb/client/OidcAuthenticationProseTests.java index a6dcc24d92b..17656d87975 100644 --- a/driver-sync/src/test/functional/com/mongodb/client/OidcAuthenticationProseTests.java +++ b/driver-sync/src/test/functional/com/mongodb/client/OidcAuthenticationProseTests.java @@ -655,7 +655,7 @@ public void test6p2ReauthenticationRetriesAndSucceedsWithCache() { // #. Perform a find operation that succeeds. performFind(mongoClient); // #. Force a reauthenication using a failCommand - failCommand(391, 2, "find", "saslStart"); + failCommand(391, 1, "find"); // #. Perform a find operation that succeeds. performFind(mongoClient); } From 1f03603a14a4d386abb619e790ae8aae1f126ea4 Mon Sep 17 00:00:00 2001 From: Maxim Katcharov Date: Wed, 31 May 2023 13:11:21 -0600 Subject: [PATCH 08/19] Apply suggestions from code review Co-authored-by: Valentin Kovalenko --- .../internal/connection/MongoCredentialWithCache.java | 3 +-- .../internal/connection/TestCommandListener.java | 11 +++++++---- 2 files changed, 8 insertions(+), 6 deletions(-) diff --git a/driver-core/src/main/com/mongodb/internal/connection/MongoCredentialWithCache.java b/driver-core/src/main/com/mongodb/internal/connection/MongoCredentialWithCache.java index d6776018e18..19c123ba810 100644 --- a/driver-core/src/main/com/mongodb/internal/connection/MongoCredentialWithCache.java +++ b/driver-core/src/main/com/mongodb/internal/connection/MongoCredentialWithCache.java @@ -35,8 +35,7 @@ public class MongoCredentialWithCache { private final Cache cache; public MongoCredentialWithCache(final MongoCredential credential) { - this.credential = credential; - this.cache = new Cache(); + this(credential, new Cache()); } private MongoCredentialWithCache(final MongoCredential credential, final Cache cache) { diff --git a/driver-core/src/test/functional/com/mongodb/internal/connection/TestCommandListener.java b/driver-core/src/test/functional/com/mongodb/internal/connection/TestCommandListener.java index c520b619fa4..1ff127e88bd 100644 --- a/driver-core/src/test/functional/com/mongodb/internal/connection/TestCommandListener.java +++ b/driver-core/src/test/functional/com/mongodb/internal/connection/TestCommandListener.java @@ -112,8 +112,9 @@ public void reset() { lock.lock(); try { events.clear(); - if (listener != null) { - listener.add("CommandListener reset"); + TestListener observedListener = listener; + if (observedListener != null) { + observedListener.clear(); } } finally { lock.unlock(); @@ -131,13 +132,15 @@ public List getEvents() { private void addEvent(final CommandEvent c) { events.add(c); - if (listener != null) { + TestListener observedListener = listener; + if (observedListener != null) { String className = c.getClass().getSimpleName() .replace("Command", "") .replace("Event", "") .toLowerCase(); // example: "saslContinue succeeded" - listener.add(c.getCommandName() + " " + className); + observedListener.add(c.getCommandName() + " " + className); + } } } From c6c8e14f5d6b362291298733bc4873d4c6e0e909 Mon Sep 17 00:00:00 2001 From: Maxim Katcharov Date: Wed, 31 May 2023 13:28:27 -0600 Subject: [PATCH 09/19] Apply suggestions from code review Co-authored-by: Valentin Kovalenko --- .../connection/OidcAuthenticator.java | 47 ++++++++++--------- .../client/OidcAuthenticationProseTests.java | 2 +- 2 files changed, 25 insertions(+), 24 deletions(-) diff --git a/driver-core/src/main/com/mongodb/internal/connection/OidcAuthenticator.java b/driver-core/src/main/com/mongodb/internal/connection/OidcAuthenticator.java index 8d91c914c73..8d451ebe714 100644 --- a/driver-core/src/main/com/mongodb/internal/connection/OidcAuthenticator.java +++ b/driver-core/src/main/com/mongodb/internal/connection/OidcAuthenticator.java @@ -83,14 +83,18 @@ public class OidcAuthenticator extends SaslAuthenticator { private static final String AWS_WEB_IDENTITY_TOKEN_FILE = "AWS_WEB_IDENTITY_TOKEN_FILE"; + @Nullable private ServerAddress serverAddress; + @Nullable private String connectionLastAccessToken; private FallbackState fallbackState = FallbackState.INITIAL; + @Nullable private BsonDocument speculativeAuthenticateResponse; + @Nullable private Function evaluateChallengeFunction; public OidcAuthenticator(final MongoCredentialWithCache credential, @@ -193,7 +197,7 @@ public void authenticate(final InternalConnection connection, final ConnectionDe String accessToken = getValidCachedAccessToken(); if (accessToken != null) { try { - authenticateUsing(connection, connectionDescription, (bytes) -> prepareTokenAsJwt(accessToken)); + authenticateUsing(connection, connectionDescription, (challenge) -> prepareTokenAsJwt(accessToken)); } catch (MongoSecurityException e) { if (triggersRetry(e)) { authLock(connection, connectionDescription); @@ -268,7 +272,7 @@ private byte[] evaluate(final byte[] challenge) { fallbackState = FallbackState.PHASE_1_CACHED_TOKEN; return prepareTokenAsJwt(cachedAccessToken); } else if (refreshCallback != null && cachedRefreshToken != null) { - assertTrue(cachedIdpInfo != null); + assertNotNull(cachedIdpInfo); // Invoke Refresh Callback using cached Refresh Token validateAllowedHosts(getMongoCredential()); fallbackState = FallbackState.PHASE_2_REFRESH_CALLBACK_TOKEN; @@ -299,8 +303,8 @@ private boolean clientIsComplete() { } private boolean shouldRetryHandler() { - OidcCacheEntry cacheEntry = getMongoCredentialWithCache().getOidcCacheEntry(); MongoCredentialWithCache mongoCredentialWithCache = getMongoCredentialWithCache(); + OidcCacheEntry cacheEntry = mongoCredentialWithCache.getOidcCacheEntry(); if (fallbackState == FallbackState.PHASE_1_CACHED_TOKEN) { // a cached access token failed mongoCredentialWithCache.setOidcCacheEntry(cacheEntry @@ -457,7 +461,7 @@ private static String readAwsTokenFromFile() { } } - private byte[] prepareUsername(@Nullable final String username) { + private static byte[] prepareUsername(@Nullable final String username) { BsonDocument document = new BsonDocument(); if (username != null) { document = document.append("n", new BsonString(username)); @@ -476,7 +480,7 @@ private byte[] populateCacheWithCallbackResultAndPrepareJwt( return prepareTokenAsJwt(idpResponse.getAccessToken()); } - private IdpInfo toIdpInfo(final byte[] challenge) { + private static IdpInfo toIdpInfo(final byte[] challenge) { BsonDocument c = new RawBsonDocument(challenge); String issuer = c.getString("issuer").getValue(); String clientId = c.getString("clientId").getValue(); @@ -513,16 +517,16 @@ private void validateAllowedHosts(final MongoCredential credential) { } @Nullable - private List getStringArray(final BsonDocument document, final String key) { - if (!document.containsKey(key) || document.isArray(key)) { + private static List getStringArray(final BsonDocument document, final String key) { + if (!document.isArray(key)) { return null; } - List result = document.getArray(key).getValues().stream() + List result = document.getArray(key).stream() // ignore non-string values from server, rather than error .filter(v -> v.isString()) .map(v -> v.asString().getValue()) .collect(Collectors.toList()); - return Collections.unmodifiableList(result); + return result; } private byte[] prepareTokenAsJwt(final String accessToken) { @@ -599,15 +603,13 @@ public static void validateBeforeUse(final MongoCredential credential) { } - public static class OidcRequestContextImpl implements OidcRequestContext { + private static class OidcRequestContextImpl implements OidcRequestContext { private final IdpInfo idpInfo; private final Duration timeout; - public OidcRequestContextImpl(final IdpInfo idpInfo, final Duration timeout) { - notNull("idpInfo", idpInfo); - notNull("timeout", timeout); - this.idpInfo = idpInfo; - this.timeout = timeout; + OidcRequestContextImpl(final IdpInfo idpInfo, final Duration timeout) { + this.idpInfo = assertNotNull(idpInfo); + this.timeout = assertNotNull(timeout); } @Override @@ -621,15 +623,14 @@ public Duration getTimeout() { } } - public static final class OidcRefreshContextImpl extends OidcRequestContextImpl + private static final class OidcRefreshContextImpl extends OidcRequestContextImpl implements OidcRefreshContext { private final String refreshToken; - public OidcRefreshContextImpl(final IdpInfo idpInfo, final String refreshToken, + OidcRefreshContextImpl(final IdpInfo idpInfo, final String refreshToken, final Duration timeout) { super(idpInfo, timeout); - notNull("refreshToken", refreshToken); - this.refreshToken = refreshToken; + this.refreshToken = assertNotNull(refreshToken); } @Override @@ -638,15 +639,15 @@ public String getRefreshToken() { } } - public static final class IdpInfoImpl implements IdpInfo { + private static final class IdpInfoImpl implements IdpInfo { private final String issuer; private final String clientId; private final List requestScopes; - public IdpInfoImpl(final String issuer, final String clientId, @Nullable final List requestScopes) { - this.issuer = issuer; - this.clientId = clientId; + IdpInfoImpl(final String issuer, final String clientId, @Nullable final List requestScopes) { + this.issuer = assertNotNull(issuer); + this.clientId = assertNotNull(clientId); this.requestScopes = requestScopes == null ? Collections.emptyList() : Collections.unmodifiableList(requestScopes); diff --git a/driver-sync/src/test/functional/com/mongodb/client/OidcAuthenticationProseTests.java b/driver-sync/src/test/functional/com/mongodb/client/OidcAuthenticationProseTests.java index 17656d87975..b335daa259d 100644 --- a/driver-sync/src/test/functional/com/mongodb/client/OidcAuthenticationProseTests.java +++ b/driver-sync/src/test/functional/com/mongodb/client/OidcAuthenticationProseTests.java @@ -83,7 +83,7 @@ public class OidcAuthenticationProseTests { public static boolean oidcTestsEnabled() { - return "true".equals(getenv().get("OIDC_TESTS_ENABLED")); + return Boolean.parseBoolean(getenv().get("OIDC_TESTS_ENABLED")); } private static final String AWS_WEB_IDENTITY_TOKEN_FILE = "AWS_WEB_IDENTITY_TOKEN_FILE"; From de08e34cfc0177c9233822e72eb8ea50b7cfcad1 Mon Sep 17 00:00:00 2001 From: Maxim Katcharov Date: Wed, 31 May 2023 13:14:27 -0600 Subject: [PATCH 10/19] temp --- driver-core/src/main/com/mongodb/MongoCredential.java | 2 +- .../internal/connection/MongoCredentialWithCache.java | 6 +++--- .../com/mongodb/internal/connection/OidcAuthenticator.java | 4 ++-- .../test/functional/com/mongodb/client/TestListener.java | 5 ++++- 4 files changed, 10 insertions(+), 7 deletions(-) diff --git a/driver-core/src/main/com/mongodb/MongoCredential.java b/driver-core/src/main/com/mongodb/MongoCredential.java index 46a81d3fcc5..418863dc21c 100644 --- a/driver-core/src/main/com/mongodb/MongoCredential.java +++ b/driver-core/src/main/com/mongodb/MongoCredential.java @@ -240,7 +240,7 @@ public final class MongoCredential { * The list of allowed hosts that will be used if no * {@link MongoCredential#ALLOWED_HOSTS_KEY} value is supplied. * The default allowed hosts are: - * {@code "*.mongodb.net", "*.mongodb-dev.net", "*.mongodbgov.net", "localhost", "127.0.0.1", "::1")} + * {@code "*.mongodb.net", "*.mongodb-dev.net", "*.mongodbgov.net", "localhost", "127.0.0.1", "::1"} * * @see #createOidcCredential(String) * @since 4.10 diff --git a/driver-core/src/main/com/mongodb/internal/connection/MongoCredentialWithCache.java b/driver-core/src/main/com/mongodb/internal/connection/MongoCredentialWithCache.java index 19c123ba810..3d65efe48ee 100644 --- a/driver-core/src/main/com/mongodb/internal/connection/MongoCredentialWithCache.java +++ b/driver-core/src/main/com/mongodb/internal/connection/MongoCredentialWithCache.java @@ -65,15 +65,15 @@ public void putInCache(final Object key, final Object value) { cache.set(key, value); } - public OidcCacheEntry getOidcCacheEntry() { + OidcCacheEntry getOidcCacheEntry() { return cache.oidcCacheEntry; } - public void setOidcCacheEntry(final OidcCacheEntry oidcCacheEntry) { + void setOidcCacheEntry(final OidcCacheEntry oidcCacheEntry) { this.cache.oidcCacheEntry = oidcCacheEntry; } - public StampedLock getOidcLock() { + StampedLock getOidcLock() { return cache.oidcLock; } diff --git a/driver-core/src/main/com/mongodb/internal/connection/OidcAuthenticator.java b/driver-core/src/main/com/mongodb/internal/connection/OidcAuthenticator.java index 8d451ebe714..72fb960e2f2 100644 --- a/driver-core/src/main/com/mongodb/internal/connection/OidcAuthenticator.java +++ b/driver-core/src/main/com/mongodb/internal/connection/OidcAuthenticator.java @@ -394,12 +394,12 @@ String getValidCachedAccessToken() { } @Nullable - public String getRefreshToken() { + String getRefreshToken() { return refreshToken; } @Nullable - public IdpInfo getIdpInfo() { + IdpInfo getIdpInfo() { return idpInfo; } diff --git a/driver-core/src/test/functional/com/mongodb/client/TestListener.java b/driver-core/src/test/functional/com/mongodb/client/TestListener.java index e014d98af19..db68065432c 100644 --- a/driver-core/src/test/functional/com/mongodb/client/TestListener.java +++ b/driver-core/src/test/functional/com/mongodb/client/TestListener.java @@ -16,6 +16,8 @@ package com.mongodb.client; +import com.mongodb.annotations.ThreadSafe; + import java.util.ArrayList; import java.util.Collections; import java.util.List; @@ -23,7 +25,8 @@ /** * A simple listener that consumes string events, which can be checked in tests. */ -public class TestListener { +@ThreadSafe +public final class TestListener { private final List events = Collections.synchronizedList(new ArrayList<>()); public void add(final String s) { From 198b7f54dd33a105cef9d68b094ec2f8d4e9e16b Mon Sep 17 00:00:00 2001 From: Maxim Katcharov Date: Wed, 31 May 2023 13:20:48 -0600 Subject: [PATCH 11/19] PR FIxes --- .../connection/TestCommandListener.java | 55 +++++++++---------- .../client/OidcAuthenticationProseTests.java | 15 ++--- 2 files changed, 30 insertions(+), 40 deletions(-) diff --git a/driver-core/src/test/functional/com/mongodb/internal/connection/TestCommandListener.java b/driver-core/src/test/functional/com/mongodb/internal/connection/TestCommandListener.java index 1ff127e88bd..2dec2f01fe1 100644 --- a/driver-core/src/test/functional/com/mongodb/internal/connection/TestCommandListener.java +++ b/driver-core/src/test/functional/com/mongodb/internal/connection/TestCommandListener.java @@ -57,8 +57,7 @@ public class TestCommandListener implements CommandListener { private final List eventTypes; private final List ignoredCommandMonitoringEvents; private final List events = new ArrayList<>(); - @Nullable - private volatile TestListener listener = null; + private final TestListener listener; private final Lock lock = new ReentrantLock(); private final Condition commandCompletedCondition = lock.newCondition(); private final boolean observeSensitiveCommands; @@ -80,42 +79,42 @@ public Codec get(final Class clazz, final CodecRegistry registry) { }); } + /** + * When a test listener is set, this command listener will send string events to the + * test listener in the form {@code " "}, where the event + * type will be lowercase and will omit the terms "command" and "event". + * For example: {@code "saslContinue succeeded"}. + * + * @see InternalStreamConnection#setRecordEverything(boolean) + * @param listener the test listener + */ + public TestCommandListener(final TestListener listener) { + this(Arrays.asList("commandStartedEvent", "commandSucceededEvent", "commandFailedEvent"), emptyList(), true, listener); + } + public TestCommandListener() { this(Arrays.asList("commandStartedEvent", "commandSucceededEvent", "commandFailedEvent"), emptyList()); } public TestCommandListener(final List eventTypes, final List ignoredCommandMonitoringEvents) { - this(eventTypes, ignoredCommandMonitoringEvents, true); + this(eventTypes, ignoredCommandMonitoringEvents, true, null); } public TestCommandListener(final List eventTypes, final List ignoredCommandMonitoringEvents, - final boolean observeSensitiveCommands) { + final boolean observeSensitiveCommands, final TestListener listener) { this.eventTypes = eventTypes; this.ignoredCommandMonitoringEvents = ignoredCommandMonitoringEvents; this.observeSensitiveCommands = observeSensitiveCommands; + this.listener = listener; } - /** - * When this is set, this command listener will send string events to the - * listener in the form {@code " "}, where the event - * type will be lowercase and will omit the terms "command" and "event". - * For example: {@code "saslContinue succeeded"}. - * - * @see InternalStreamConnection#setRecordEverything(boolean) - * @param eventStrings the test listener - */ - public void setEventStrings(final TestListener eventStrings) { - this.listener = eventStrings; - } + public void reset() { lock.lock(); try { events.clear(); - TestListener observedListener = listener; - if (observedListener != null) { - observedListener.clear(); - } + listener.clear(); } finally { lock.unlock(); } @@ -132,16 +131,12 @@ public List getEvents() { private void addEvent(final CommandEvent c) { events.add(c); - TestListener observedListener = listener; - if (observedListener != null) { - String className = c.getClass().getSimpleName() - .replace("Command", "") - .replace("Event", "") - .toLowerCase(); - // example: "saslContinue succeeded" - observedListener.add(c.getCommandName() + " " + className); - } - } + String className = c.getClass().getSimpleName() + .replace("Command", "") + .replace("Event", "") + .toLowerCase(); + // example: "saslContinue succeeded" + listener.add(c.getCommandName() + " " + className); } public CommandStartedEvent getCommandStartedEvent(final String commandName) { diff --git a/driver-sync/src/test/functional/com/mongodb/client/OidcAuthenticationProseTests.java b/driver-sync/src/test/functional/com/mongodb/client/OidcAuthenticationProseTests.java index b335daa259d..056d251225a 100644 --- a/driver-sync/src/test/functional/com/mongodb/client/OidcAuthenticationProseTests.java +++ b/driver-sync/src/test/functional/com/mongodb/client/OidcAuthenticationProseTests.java @@ -401,8 +401,7 @@ public void test4p4ErrorClearsCache() { .setPathSupplier(() -> tokens.remove()) .setEventListener(listener); - TestCommandListener commandListener = new TestCommandListener(); - commandListener.setEventStrings(listener); + TestCommandListener commandListener = new TestCommandListener(listener); MongoClientSettings clientSettings = createSettings(OIDC_URL, onRequest, onRefresh, null, commandListener); try (MongoClient mongoClient = createMongoClient(clientSettings)) { @@ -483,8 +482,7 @@ public void testEventListenerMustNotLogReauthentication() { .setPathSupplier(() -> tokens.remove()) .setEventListener(listener); - TestCommandListener commandListener = new TestCommandListener(); - commandListener.setEventStrings(listener); + TestCommandListener commandListener = new TestCommandListener(listener); MongoClientSettings clientSettings = createSettings(OIDC_URL, onRequest, onRefresh, null, commandListener); try (MongoClient mongoClient = createMongoClient(clientSettings)) { @@ -539,8 +537,7 @@ public void test5SpeculativeAuthentication() { // #. Create a client with a request callback that returns a valid token that will not expire soon. TestListener listener = new TestListener(); TestCallback onRequest = createCallback().setEventListener(listener); - TestCommandListener commandListener = new TestCommandListener(); - commandListener.setEventStrings(listener); + TestCommandListener commandListener = new TestCommandListener(listener); MongoClientSettings clientSettings = createSettings(OIDC_URL, onRequest, null, null, commandListener); try (MongoClient mongoClient = createMongoClient(clientSettings)) { // instead of setting failpoints for saslStart, we inspect events @@ -560,8 +557,7 @@ public void test5SpeculativeAuthentication() { @Test public void testAutomaticAuthUsesSpeculative() { TestListener listener = new TestListener(); - TestCommandListener commandListener = new TestCommandListener(); - commandListener.setEventStrings(listener); + TestCommandListener commandListener = new TestCommandListener(listener); MongoClientSettings settings = createSettings( AWS_OIDC_URL, null, null, Arrays.asList(), commandListener); @@ -585,8 +581,7 @@ public void test6p1ReauthenticationSucceeds() { TestCallback onRefresh = createCallback().setEventListener(listener); // #. Create a client with the callbacks and an event listener capable of listening for SASL commands. - TestCommandListener commandListener = new TestCommandListener(); - commandListener.setEventStrings(listener); + TestCommandListener commandListener = new TestCommandListener(listener); MongoClientSettings clientSettings = createSettings(OIDC_URL, onRequest, onRefresh, null, commandListener); try (MongoClient mongoClient = createMongoClient(clientSettings)) { From c1c7a50e771568197f6698fb667b48182c1a71c4 Mon Sep 17 00:00:00 2001 From: Maxim Katcharov Date: Wed, 31 May 2023 14:00:05 -0600 Subject: [PATCH 12/19] PR fixes --- .../internal/connection/AwsAuthenticator.java | 9 -- .../InternalStreamConnectionFactory.java | 2 - .../connection/OidcAuthenticator.java | 108 +++++++++--------- .../connection/SaslAuthenticator.java | 13 +++ .../OidcAuthenticationProseTests.java | 13 ++- 5 files changed, 78 insertions(+), 67 deletions(-) rename driver-sync/src/test/functional/com/mongodb/{client => internal/connection}/OidcAuthenticationProseTests.java (99%) diff --git a/driver-core/src/main/com/mongodb/internal/connection/AwsAuthenticator.java b/driver-core/src/main/com/mongodb/internal/connection/AwsAuthenticator.java index c2382bc9ba5..04c10af5385 100644 --- a/driver-core/src/main/com/mongodb/internal/connection/AwsAuthenticator.java +++ b/driver-core/src/main/com/mongodb/internal/connection/AwsAuthenticator.java @@ -172,13 +172,4 @@ private AwsCredential createAwsCredential() { return awsCredential; } } - - private static byte[] toBson(final BsonDocument document) { - byte[] bytes; - BasicOutputBuffer buffer = new BasicOutputBuffer(); - new BsonDocumentCodec().encode(new BsonBinaryWriter(buffer), document, EncoderContext.builder().build()); - bytes = new byte[buffer.size()]; - System.arraycopy(buffer.getInternalBuffer(), 0, bytes, 0, buffer.getSize()); - return bytes; - } } diff --git a/driver-core/src/main/com/mongodb/internal/connection/InternalStreamConnectionFactory.java b/driver-core/src/main/com/mongodb/internal/connection/InternalStreamConnectionFactory.java index 9d4dc8aaaca..f879c642ccd 100644 --- a/driver-core/src/main/com/mongodb/internal/connection/InternalStreamConnectionFactory.java +++ b/driver-core/src/main/com/mongodb/internal/connection/InternalStreamConnectionFactory.java @@ -33,7 +33,6 @@ import static com.mongodb.assertions.Assertions.notNull; import static com.mongodb.internal.connection.ClientMetadataHelper.createClientMetadataDocument; -import static com.mongodb.internal.connection.OidcAuthenticator.OidcValidator.validateBeforeUse; class InternalStreamConnectionFactory implements InternalConnectionFactory { private final ClusterConnectionMode clusterConnectionMode; @@ -108,7 +107,6 @@ private Authenticator createAuthenticator(final MongoCredentialWithCache credent case MONGODB_AWS: return new AwsAuthenticator(credential, clusterConnectionMode, serverApi); case MONGODB_OIDC: - validateBeforeUse(credential.getCredential()); return new OidcAuthenticator(credential, clusterConnectionMode, serverApi); default: throw new IllegalArgumentException("Unsupported authentication mechanism " + authenticationMechanism); diff --git a/driver-core/src/main/com/mongodb/internal/connection/OidcAuthenticator.java b/driver-core/src/main/com/mongodb/internal/connection/OidcAuthenticator.java index 72fb960e2f2..68e26f90a29 100644 --- a/driver-core/src/main/com/mongodb/internal/connection/OidcAuthenticator.java +++ b/driver-core/src/main/com/mongodb/internal/connection/OidcAuthenticator.java @@ -31,14 +31,11 @@ import com.mongodb.connection.ConnectionDescription; import com.mongodb.internal.Locks; import com.mongodb.internal.Timeout; +import com.mongodb.internal.VisibleForTesting; import com.mongodb.lang.Nullable; -import org.bson.BsonBinaryWriter; import org.bson.BsonDocument; import org.bson.BsonString; import org.bson.RawBsonDocument; -import org.bson.codecs.BsonDocumentCodec; -import org.bson.codecs.EncoderContext; -import org.bson.io.BasicOutputBuffer; import org.jetbrains.annotations.NotNull; import javax.security.sasl.SaslClient; @@ -70,6 +67,7 @@ import static com.mongodb.assertions.Assertions.assertNotNull; import static com.mongodb.assertions.Assertions.assertTrue; import static com.mongodb.assertions.Assertions.notNull; +import static com.mongodb.internal.connection.OidcAuthenticator.OidcValidator.validateBeforeUse; import static java.lang.String.format; /** @@ -100,6 +98,7 @@ public class OidcAuthenticator extends SaslAuthenticator { public OidcAuthenticator(final MongoCredentialWithCache credential, final ClusterConnectionMode clusterConnectionMode, @Nullable final ServerApi serverApi) { super(credential, clusterConnectionMode, serverApi); + validateBeforeUse(credential.getCredential()); if (getMongoCredential().getAuthenticationMechanism() != MONGODB_OIDC) { throw new MongoException("Incorrect mechanism: " + getMongoCredential().getMechanism()); @@ -123,7 +122,7 @@ protected SaslClient createSaslClient(final ServerAddress serverAddress) { public BsonDocument createSpeculativeAuthenticateCommand(final InternalConnection connection) { try { if (isAutomaticAuthentication()) { - return wrapInSpeculative(prepareAwsTokenFromFile()); + return wrapInSpeculative(prepareAwsTokenFromFileAsJwt()); } String cachedAccessToken = getValidCachedAccessToken(); MongoCredentialWithCache mongoCredentialWithCache = getMongoCredentialWithCache(); @@ -157,7 +156,6 @@ public BsonDocument getSpeculativeAuthenticateResponse() { this.speculativeAuthenticateResponse = null; if (response == null) { this.connectionLastAccessToken = null; - this.fallbackState = FallbackState.INITIAL; } return response; } @@ -185,7 +183,6 @@ private OidcRequestCallback getRequestCallback() { public void reauthenticate(final InternalConnection connection) { // method must only be called after original handshake: assertTrue(connection.opened()); - fallbackState = FallbackState.INITIAL; authLock(connection, connection.getDescription()); } @@ -231,6 +228,7 @@ private void authenticateUsing( } private void authLock(final InternalConnection connection, final ConnectionDescription connectionDescription) { + fallbackState = FallbackState.INITIAL; Locks.withLock(getMongoCredentialWithCache().getOidcLock(), () -> { while (true) { try { @@ -249,7 +247,7 @@ private void authLock(final InternalConnection connection, final ConnectionDescr private byte[] evaluate(final byte[] challenge) { if (isAutomaticAuthentication()) { - return prepareAwsTokenFromFile(); + return prepareAwsTokenFromFileAsJwt(); } OidcRequestCallback requestCallback = assertNotNull(getRequestCallback()); @@ -281,13 +279,27 @@ private byte[] evaluate(final byte[] challenge) { return populateCacheWithCallbackResultAndPrepareJwt(cachedIdpInfo, result); } else { // cache is empty + + /* + A check for present idp info short-circuits phase-3a. + + If a challenge is present, it can only be a response to a + "principal-request", so the challenge must be the resulting + idp info. Such a request is made during speculative auth, + though the source is unimportant, as long as we detect and + use it here. + + Checking that the fallback state is not phase-3a ensures that + this does not loop infinitely in the case of a bug. + */ boolean idpInfoNotPresent = challenge.length == 0; if (fallbackState != FallbackState.PHASE_3A_PRINCIPAL && idpInfoNotPresent) { fallbackState = FallbackState.PHASE_3A_PRINCIPAL; return prepareUsername(mongoCredentialWithCache.getCredential().getUserName()); } else { IdpInfo idpInfo = toIdpInfo(challenge); - IdpResponse result = invokeRequestCallback(requestCallback, idpInfo); + validateAllowedHosts(getMongoCredential()); + IdpResponse result = requestCallback.onRequest(new OidcRequestContextImpl(idpInfo, CALLBACK_TIMEOUT)); fallbackState = FallbackState.PHASE_3B_REQUEST_CALLBACK_TOKEN; return populateCacheWithCallbackResultAndPrepareJwt(idpInfo, result); } @@ -441,11 +453,6 @@ public byte[] evaluateChallengeInternal(final byte[] challenge) { } } - private static byte[] prepareAwsTokenFromFile() { - return toBson(new BsonDocument() - .append("jwt", new BsonString(readAwsTokenFromFile()))); - } - private static String readAwsTokenFromFile() { String path = System.getenv(AWS_WEB_IDENTITY_TOKEN_FILE); if (path == null) { @@ -490,12 +497,6 @@ private static IdpInfo toIdpInfo(final byte[] challenge) { getStringArray(c, "requestScopes")); } - private IdpResponse invokeRequestCallback(final OidcRequestCallback requestCallback, - final IdpInfo serverInfo) { - validateAllowedHosts(getMongoCredential()); - return requestCallback.onRequest(new OidcRequestContextImpl(serverInfo, CALLBACK_TIMEOUT)); - } - private void validateAllowedHosts(final MongoCredential credential) { List allowedHosts = assertNotNull(credential.getMechanismProperty(ALLOWED_HOSTS_KEY, DEFAULT_ALLOWED_HOSTS)); String host = serverAddress.getHost(); @@ -531,15 +532,16 @@ private static List getStringArray(final BsonDocument document, final St private byte[] prepareTokenAsJwt(final String accessToken) { connectionLastAccessToken = accessToken; - return toBson(new BsonDocument().append("jwt", new BsonString(accessToken))); + return toJwtDocument(accessToken); } - private static byte[] toBson(final BsonDocument document) { - BasicOutputBuffer buffer = new BasicOutputBuffer(); - new BsonDocumentCodec().encode(new BsonBinaryWriter(buffer), document, EncoderContext.builder().build()); - byte[] bytes = new byte[buffer.size()]; - System.arraycopy(buffer.getInternalBuffer(), 0, bytes, 0, buffer.getSize()); - return bytes; + private static byte[] prepareAwsTokenFromFileAsJwt() { + String accessToken = readAwsTokenFromFile(); + return toJwtDocument(accessToken); + } + + private static byte[] toJwtDocument(final String accessToken) { + return toBson(new BsonDocument().append("jwt", new BsonString(accessToken))); } /** @@ -572,38 +574,38 @@ public static void validateCreateOidcCredential(@Nullable final char[] password) } } + @VisibleForTesting(otherwise = VisibleForTesting.AccessModifier.PACKAGE) public static void validateBeforeUse(final MongoCredential credential) { AuthenticationMechanism mechanism = credential.getAuthenticationMechanism(); String userName = credential.getUserName(); - - if (mechanism == AuthenticationMechanism.MONGODB_OIDC) { - Object providerName = credential.getMechanismProperty(PROVIDER_NAME_KEY, null); - Object requestCallback = credential.getMechanismProperty(REQUEST_TOKEN_CALLBACK_KEY, null); - Object refreshCallback = credential.getMechanismProperty(REFRESH_TOKEN_CALLBACK_KEY, null); - if (providerName == null) { - // callback - if (requestCallback == null) { - throw new IllegalArgumentException("Either " + PROVIDER_NAME_KEY + " or " - + REQUEST_TOKEN_CALLBACK_KEY + " must be specified"); - } - } else { - // automatic - if (userName != null) { - throw new IllegalArgumentException("user name must not be specified when " + PROVIDER_NAME_KEY + " is specified"); - } - if (requestCallback != null) { - throw new IllegalArgumentException(REQUEST_TOKEN_CALLBACK_KEY + " must not be specified when " + PROVIDER_NAME_KEY + " is specified"); - } - if (refreshCallback != null) { - throw new IllegalArgumentException(REFRESH_TOKEN_CALLBACK_KEY + " must not be specified when " + PROVIDER_NAME_KEY + " is specified"); - } + assertTrue(mechanism == AuthenticationMechanism.MONGODB_OIDC); + Object providerName = credential.getMechanismProperty(PROVIDER_NAME_KEY, null); + Object requestCallback = credential.getMechanismProperty(REQUEST_TOKEN_CALLBACK_KEY, null); + Object refreshCallback = credential.getMechanismProperty(REFRESH_TOKEN_CALLBACK_KEY, null); + if (providerName == null) { + // callback + if (requestCallback == null) { + throw new IllegalArgumentException("Either " + PROVIDER_NAME_KEY + " or " + + REQUEST_TOKEN_CALLBACK_KEY + " must be specified"); + } + } else { + // automatic + if (userName != null) { + throw new IllegalArgumentException("user name must not be specified when " + PROVIDER_NAME_KEY + " is specified"); + } + if (requestCallback != null) { + throw new IllegalArgumentException(REQUEST_TOKEN_CALLBACK_KEY + " must not be specified when " + PROVIDER_NAME_KEY + " is specified"); + } + if (refreshCallback != null) { + throw new IllegalArgumentException(REFRESH_TOKEN_CALLBACK_KEY + " must not be specified when " + PROVIDER_NAME_KEY + " is specified"); } } } } - private static class OidcRequestContextImpl implements OidcRequestContext { + @VisibleForTesting(otherwise = VisibleForTesting.AccessModifier.PRIVATE) + static class OidcRequestContextImpl implements OidcRequestContext { private final IdpInfo idpInfo; private final Duration timeout; @@ -623,7 +625,8 @@ public Duration getTimeout() { } } - private static final class OidcRefreshContextImpl extends OidcRequestContextImpl + @VisibleForTesting(otherwise = VisibleForTesting.AccessModifier.PRIVATE) + static final class OidcRefreshContextImpl extends OidcRequestContextImpl implements OidcRefreshContext { private final String refreshToken; @@ -639,7 +642,8 @@ public String getRefreshToken() { } } - private static final class IdpInfoImpl implements IdpInfo { + @VisibleForTesting(otherwise = VisibleForTesting.AccessModifier.PRIVATE) + static final class IdpInfoImpl implements IdpInfo { private final String issuer; private final String clientId; diff --git a/driver-core/src/main/com/mongodb/internal/connection/SaslAuthenticator.java b/driver-core/src/main/com/mongodb/internal/connection/SaslAuthenticator.java index 296cb7540e4..e399b00bea8 100644 --- a/driver-core/src/main/com/mongodb/internal/connection/SaslAuthenticator.java +++ b/driver-core/src/main/com/mongodb/internal/connection/SaslAuthenticator.java @@ -32,9 +32,13 @@ import com.mongodb.lang.NonNull; import com.mongodb.lang.Nullable; import org.bson.BsonBinary; +import org.bson.BsonBinaryWriter; import org.bson.BsonDocument; import org.bson.BsonInt32; import org.bson.BsonString; +import org.bson.codecs.BsonDocumentCodec; +import org.bson.codecs.EncoderContext; +import org.bson.io.BasicOutputBuffer; import javax.security.auth.Subject; import javax.security.auth.login.LoginException; @@ -285,6 +289,15 @@ void doAsSubject(final java.security.PrivilegedAction action) { } } + static byte[] toBson(final BsonDocument document) { + byte[] bytes; + BasicOutputBuffer buffer = new BasicOutputBuffer(); + new BsonDocumentCodec().encode(new BsonBinaryWriter(buffer), document, EncoderContext.builder().build()); + bytes = new byte[buffer.size()]; + System.arraycopy(buffer.getInternalBuffer(), 0, bytes, 0, buffer.getSize()); + return bytes; + } + private final class Continuator implements SingleResultCallback { private final SaslClient saslClient; private final BsonDocument saslStartDocument; diff --git a/driver-sync/src/test/functional/com/mongodb/client/OidcAuthenticationProseTests.java b/driver-sync/src/test/functional/com/mongodb/internal/connection/OidcAuthenticationProseTests.java similarity index 99% rename from driver-sync/src/test/functional/com/mongodb/client/OidcAuthenticationProseTests.java rename to driver-sync/src/test/functional/com/mongodb/internal/connection/OidcAuthenticationProseTests.java index 056d251225a..74e95d7e253 100644 --- a/driver-sync/src/test/functional/com/mongodb/client/OidcAuthenticationProseTests.java +++ b/driver-sync/src/test/functional/com/mongodb/internal/connection/OidcAuthenticationProseTests.java @@ -14,7 +14,7 @@ * limitations under the License. */ -package com.mongodb.client; +package com.mongodb.internal.connection; import com.mongodb.ConnectionString; import com.mongodb.MongoClientSettings; @@ -24,10 +24,10 @@ import com.mongodb.MongoCredential.IdpResponse; import com.mongodb.MongoCredential.OidcRefreshCallback; import com.mongodb.MongoSecurityException; +import com.mongodb.client.MongoClient; +import com.mongodb.client.MongoClients; +import com.mongodb.client.TestListener; import com.mongodb.event.CommandListener; -import com.mongodb.internal.connection.InternalStreamConnection; -import com.mongodb.internal.connection.OidcAuthenticator; -import com.mongodb.internal.connection.TestCommandListener; import com.mongodb.lang.Nullable; import org.bson.BsonArray; import org.bson.BsonBoolean; @@ -80,6 +80,11 @@ import static org.junit.jupiter.api.Assumptions.assumeTrue; import static util.ThreadTestHelpers.executeAll; + +/** + * See + * Prose Tests. + */ public class OidcAuthenticationProseTests { public static boolean oidcTestsEnabled() { From a76c7dc40eeaf6555f0814df1cb3800f90aa66da Mon Sep 17 00:00:00 2001 From: Maxim Katcharov Date: Wed, 31 May 2023 14:01:31 -0600 Subject: [PATCH 13/19] Implement reauthenticate, for spec tests --- .../main/com/mongodb/internal/connection/Authenticator.java | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/driver-core/src/main/com/mongodb/internal/connection/Authenticator.java b/driver-core/src/main/com/mongodb/internal/connection/Authenticator.java index 96c66affb0c..45e0b078452 100644 --- a/driver-core/src/main/com/mongodb/internal/connection/Authenticator.java +++ b/driver-core/src/main/com/mongodb/internal/connection/Authenticator.java @@ -101,8 +101,7 @@ abstract void authenticateAsync(InternalConnection connection, ConnectionDescrip SingleResultCallback callback); public void reauthenticate(final InternalConnection connection) { - throw new UnsupportedOperationException( - "Reauthentication requested by server but is not supported by specified mechanism."); + authenticate(connection, connection.getDescription()); } } From feab34dc0754ef1140748393d59fcb595f1e683c Mon Sep 17 00:00:00 2001 From: Maxim Katcharov Date: Wed, 31 May 2023 14:21:23 -0600 Subject: [PATCH 14/19] PR Fixes --- .../mongodb/internal/connection/AwsAuthenticator.java | 4 ---- .../internal/connection/OidcAuthenticator.java | 11 ++++------- .../internal/connection/TestCommandListener.java | 11 ++++++++--- .../com/mongodb/client/unified/Entities.java | 3 ++- 4 files changed, 14 insertions(+), 15 deletions(-) diff --git a/driver-core/src/main/com/mongodb/internal/connection/AwsAuthenticator.java b/driver-core/src/main/com/mongodb/internal/connection/AwsAuthenticator.java index 04c10af5385..35f9f8120ee 100644 --- a/driver-core/src/main/com/mongodb/internal/connection/AwsAuthenticator.java +++ b/driver-core/src/main/com/mongodb/internal/connection/AwsAuthenticator.java @@ -26,14 +26,10 @@ import com.mongodb.internal.authentication.AwsCredentialHelper; import com.mongodb.lang.Nullable; import org.bson.BsonBinary; -import org.bson.BsonBinaryWriter; import org.bson.BsonDocument; import org.bson.BsonInt32; import org.bson.BsonString; import org.bson.RawBsonDocument; -import org.bson.codecs.BsonDocumentCodec; -import org.bson.codecs.EncoderContext; -import org.bson.io.BasicOutputBuffer; import javax.security.sasl.SaslClient; import javax.security.sasl.SaslException; diff --git a/driver-core/src/main/com/mongodb/internal/connection/OidcAuthenticator.java b/driver-core/src/main/com/mongodb/internal/connection/OidcAuthenticator.java index 68e26f90a29..519e5d82e1e 100644 --- a/driver-core/src/main/com/mongodb/internal/connection/OidcAuthenticator.java +++ b/driver-core/src/main/com/mongodb/internal/connection/OidcAuthenticator.java @@ -56,17 +56,16 @@ import static com.mongodb.AuthenticationMechanism.MONGODB_OIDC; import static com.mongodb.MongoCredential.ALLOWED_HOSTS_KEY; import static com.mongodb.MongoCredential.DEFAULT_ALLOWED_HOSTS; +import static com.mongodb.MongoCredential.OidcRefreshCallback; import static com.mongodb.MongoCredential.OidcRefreshContext; +import static com.mongodb.MongoCredential.OidcRequestCallback; import static com.mongodb.MongoCredential.OidcRequestContext; import static com.mongodb.MongoCredential.PROVIDER_NAME_KEY; import static com.mongodb.MongoCredential.REFRESH_TOKEN_CALLBACK_KEY; import static com.mongodb.MongoCredential.REQUEST_TOKEN_CALLBACK_KEY; -import static com.mongodb.MongoCredential.OidcRefreshCallback; -import static com.mongodb.MongoCredential.OidcRequestCallback; import static com.mongodb.assertions.Assertions.assertFalse; import static com.mongodb.assertions.Assertions.assertNotNull; import static com.mongodb.assertions.Assertions.assertTrue; -import static com.mongodb.assertions.Assertions.notNull; import static com.mongodb.internal.connection.OidcAuthenticator.OidcValidator.validateBeforeUse; import static java.lang.String.format; @@ -81,7 +80,6 @@ public class OidcAuthenticator extends SaslAuthenticator { private static final String AWS_WEB_IDENTITY_TOKEN_FILE = "AWS_WEB_IDENTITY_TOKEN_FILE"; - @Nullable private ServerAddress serverAddress; @Nullable @@ -449,7 +447,7 @@ public boolean isComplete() { } public byte[] evaluateChallengeInternal(final byte[] challenge) { - return evaluateChallengeFunction.apply(challenge); + return assertNotNull(evaluateChallengeFunction).apply(challenge); } } @@ -522,12 +520,11 @@ private static List getStringArray(final BsonDocument document, final St if (!document.isArray(key)) { return null; } - List result = document.getArray(key).stream() + return document.getArray(key).stream() // ignore non-string values from server, rather than error .filter(v -> v.isString()) .map(v -> v.asString().getValue()) .collect(Collectors.toList()); - return result; } private byte[] prepareTokenAsJwt(final String accessToken) { diff --git a/driver-core/src/test/functional/com/mongodb/internal/connection/TestCommandListener.java b/driver-core/src/test/functional/com/mongodb/internal/connection/TestCommandListener.java index 2dec2f01fe1..d7bf8529b48 100644 --- a/driver-core/src/test/functional/com/mongodb/internal/connection/TestCommandListener.java +++ b/driver-core/src/test/functional/com/mongodb/internal/connection/TestCommandListener.java @@ -57,6 +57,7 @@ public class TestCommandListener implements CommandListener { private final List eventTypes; private final List ignoredCommandMonitoringEvents; private final List events = new ArrayList<>(); + @Nullable private final TestListener listener; private final Lock lock = new ReentrantLock(); private final Condition commandCompletedCondition = lock.newCondition(); @@ -101,7 +102,7 @@ public TestCommandListener(final List eventTypes, final List ign } public TestCommandListener(final List eventTypes, final List ignoredCommandMonitoringEvents, - final boolean observeSensitiveCommands, final TestListener listener) { + final boolean observeSensitiveCommands, @Nullable final TestListener listener) { this.eventTypes = eventTypes; this.ignoredCommandMonitoringEvents = ignoredCommandMonitoringEvents; this.observeSensitiveCommands = observeSensitiveCommands; @@ -114,7 +115,9 @@ public void reset() { lock.lock(); try { events.clear(); - listener.clear(); + if (listener != null) { + listener.clear(); + } } finally { lock.unlock(); } @@ -136,7 +139,9 @@ private void addEvent(final CommandEvent c) { .replace("Event", "") .toLowerCase(); // example: "saslContinue succeeded" - listener.add(c.getCommandName() + " " + className); + if (listener != null) { + listener.add(c.getCommandName() + " " + className); + } } public CommandStartedEvent getCommandStartedEvent(final String commandName) { diff --git a/driver-sync/src/test/functional/com/mongodb/client/unified/Entities.java b/driver-sync/src/test/functional/com/mongodb/client/unified/Entities.java index 3bc32e26c26..69174462f77 100644 --- a/driver-sync/src/test/functional/com/mongodb/client/unified/Entities.java +++ b/driver-sync/src/test/functional/com/mongodb/client/unified/Entities.java @@ -373,7 +373,8 @@ private void initClient(final BsonDocument entity, final String id, TestCommandListener testCommandListener = new TestCommandListener( entity.getArray("observeEvents").stream() .map(type -> type.asString().getValue()).collect(Collectors.toList()), - ignoreCommandMonitoringEvents, entity.getBoolean("observeSensitiveCommands", BsonBoolean.FALSE).getValue()); + ignoreCommandMonitoringEvents, entity.getBoolean("observeSensitiveCommands", BsonBoolean.FALSE).getValue(), + null); clientSettingsBuilder.addCommandListener(testCommandListener); putEntity(id + "-command-listener", testCommandListener, clientCommandListeners); From 537bead12d49cff8126489b92f24a5ad7c08378f Mon Sep 17 00:00:00 2001 From: Maxim Katcharov Date: Wed, 31 May 2023 14:37:40 -0600 Subject: [PATCH 15/19] PR test fixes --- .../connection/OidcAuthenticator.java | 2 -- .../com/mongodb/AuthConnectionStringTest.java | 26 +++++++++---------- 2 files changed, 12 insertions(+), 16 deletions(-) diff --git a/driver-core/src/main/com/mongodb/internal/connection/OidcAuthenticator.java b/driver-core/src/main/com/mongodb/internal/connection/OidcAuthenticator.java index 519e5d82e1e..6ecb9c80c35 100644 --- a/driver-core/src/main/com/mongodb/internal/connection/OidcAuthenticator.java +++ b/driver-core/src/main/com/mongodb/internal/connection/OidcAuthenticator.java @@ -573,9 +573,7 @@ public static void validateCreateOidcCredential(@Nullable final char[] password) @VisibleForTesting(otherwise = VisibleForTesting.AccessModifier.PACKAGE) public static void validateBeforeUse(final MongoCredential credential) { - AuthenticationMechanism mechanism = credential.getAuthenticationMechanism(); String userName = credential.getUserName(); - assertTrue(mechanism == AuthenticationMechanism.MONGODB_OIDC); Object providerName = credential.getMechanismProperty(PROVIDER_NAME_KEY, null); Object requestCallback = credential.getMechanismProperty(REQUEST_TOKEN_CALLBACK_KEY, null); Object refreshCallback = credential.getMechanismProperty(REFRESH_TOKEN_CALLBACK_KEY, null); diff --git a/driver-core/src/test/unit/com/mongodb/AuthConnectionStringTest.java b/driver-core/src/test/unit/com/mongodb/AuthConnectionStringTest.java index 019b8d15a4b..3bac4afa179 100644 --- a/driver-core/src/test/unit/com/mongodb/AuthConnectionStringTest.java +++ b/driver-core/src/test/unit/com/mongodb/AuthConnectionStringTest.java @@ -113,20 +113,18 @@ private MongoCredential getMongoCredential() { MongoCredential credential = connectionString.getCredential(); if (credential != null) { BsonArray callbacks = (BsonArray) getExpectedValue("callback"); - if (callbacks != null) { - for (BsonValue v : callbacks) { - String string = ((BsonString) v).getValue(); - if ("oidcRequest".equals(string)) { - credential = credential.withMechanismProperty( - REQUEST_TOKEN_CALLBACK_KEY, - (MongoCredential.OidcRequestCallback) (context) -> null); - } else if ("oidcRefresh".equals(string)) { - credential = credential.withMechanismProperty( - REFRESH_TOKEN_CALLBACK_KEY, - (MongoCredential.OidcRefreshCallback) (context) -> null); - } else { - fail("Unsupported callback: " + string); - } + for (BsonValue v : callbacks) { + String string = ((BsonString) v).getValue(); + if ("oidcRequest".equals(string)) { + credential = credential.withMechanismProperty( + REQUEST_TOKEN_CALLBACK_KEY, + (MongoCredential.OidcRequestCallback) (context) -> null); + } else if ("oidcRefresh".equals(string)) { + credential = credential.withMechanismProperty( + REFRESH_TOKEN_CALLBACK_KEY, + (MongoCredential.OidcRefreshCallback) (context) -> null); + } else { + fail("Unsupported callback: " + string); } } OidcAuthenticator.OidcValidator.validateBeforeUse(credential); From 9f38271fe61f80db16d77150bc1a1a8cbec3499a Mon Sep 17 00:00:00 2001 From: Maxim Katcharov Date: Wed, 31 May 2023 14:40:56 -0600 Subject: [PATCH 16/19] PR Fixes --- .../connection/OidcAuthenticator.java | 3 +- .../com/mongodb/AuthConnectionStringTest.java | 31 +++++++++++-------- 2 files changed, 20 insertions(+), 14 deletions(-) diff --git a/driver-core/src/main/com/mongodb/internal/connection/OidcAuthenticator.java b/driver-core/src/main/com/mongodb/internal/connection/OidcAuthenticator.java index 6ecb9c80c35..4f748d35c8e 100644 --- a/driver-core/src/main/com/mongodb/internal/connection/OidcAuthenticator.java +++ b/driver-core/src/main/com/mongodb/internal/connection/OidcAuthenticator.java @@ -80,6 +80,7 @@ public class OidcAuthenticator extends SaslAuthenticator { private static final String AWS_WEB_IDENTITY_TOKEN_FILE = "AWS_WEB_IDENTITY_TOKEN_FILE"; + @Nullable private ServerAddress serverAddress; @Nullable @@ -497,7 +498,7 @@ private static IdpInfo toIdpInfo(final byte[] challenge) { private void validateAllowedHosts(final MongoCredential credential) { List allowedHosts = assertNotNull(credential.getMechanismProperty(ALLOWED_HOSTS_KEY, DEFAULT_ALLOWED_HOSTS)); - String host = serverAddress.getHost(); + String host = assertNotNull(serverAddress).getHost(); boolean permitted = allowedHosts.stream().anyMatch(allowedHost -> { if (allowedHost.startsWith("*.")) { String ending = allowedHost.substring(1); diff --git a/driver-core/src/test/unit/com/mongodb/AuthConnectionStringTest.java b/driver-core/src/test/unit/com/mongodb/AuthConnectionStringTest.java index 3bac4afa179..7f4acab857d 100644 --- a/driver-core/src/test/unit/com/mongodb/AuthConnectionStringTest.java +++ b/driver-core/src/test/unit/com/mongodb/AuthConnectionStringTest.java @@ -36,6 +36,7 @@ import java.util.Collection; import java.util.List; +import static com.mongodb.AuthenticationMechanism.MONGODB_OIDC; import static com.mongodb.MongoCredential.REFRESH_TOKEN_CALLBACK_KEY; import static com.mongodb.MongoCredential.REQUEST_TOKEN_CALLBACK_KEY; @@ -113,21 +114,25 @@ private MongoCredential getMongoCredential() { MongoCredential credential = connectionString.getCredential(); if (credential != null) { BsonArray callbacks = (BsonArray) getExpectedValue("callback"); - for (BsonValue v : callbacks) { - String string = ((BsonString) v).getValue(); - if ("oidcRequest".equals(string)) { - credential = credential.withMechanismProperty( - REQUEST_TOKEN_CALLBACK_KEY, - (MongoCredential.OidcRequestCallback) (context) -> null); - } else if ("oidcRefresh".equals(string)) { - credential = credential.withMechanismProperty( - REFRESH_TOKEN_CALLBACK_KEY, - (MongoCredential.OidcRefreshCallback) (context) -> null); - } else { - fail("Unsupported callback: " + string); + if (callbacks != null) { + for (BsonValue v : callbacks) { + String string = ((BsonString) v).getValue(); + if ("oidcRequest".equals(string)) { + credential = credential.withMechanismProperty( + REQUEST_TOKEN_CALLBACK_KEY, + (MongoCredential.OidcRequestCallback) (context) -> null); + } else if ("oidcRefresh".equals(string)) { + credential = credential.withMechanismProperty( + REFRESH_TOKEN_CALLBACK_KEY, + (MongoCredential.OidcRefreshCallback) (context) -> null); + } else { + fail("Unsupported callback: " + string); + } } } - OidcAuthenticator.OidcValidator.validateBeforeUse(credential); + if (MONGODB_OIDC.getMechanismName().equals(credential.getMechanism())) { + OidcAuthenticator.OidcValidator.validateBeforeUse(credential); + } } return credential; } From 4d7f3c9a0508978be7dc6ef353e87a6cb8562b3c Mon Sep 17 00:00:00 2001 From: Maxim Katcharov Date: Fri, 2 Jun 2023 14:20:49 -0600 Subject: [PATCH 17/19] PR fixes --- .../com/mongodb/internal/connection/OidcAuthenticator.java | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/driver-core/src/main/com/mongodb/internal/connection/OidcAuthenticator.java b/driver-core/src/main/com/mongodb/internal/connection/OidcAuthenticator.java index 4f748d35c8e..9280510c9d9 100644 --- a/driver-core/src/main/com/mongodb/internal/connection/OidcAuthenticator.java +++ b/driver-core/src/main/com/mongodb/internal/connection/OidcAuthenticator.java @@ -72,7 +72,7 @@ /** *

This class is not part of the public API and may be removed or changed at any time

*/ -public class OidcAuthenticator extends SaslAuthenticator { +public final class OidcAuthenticator extends SaslAuthenticator { private static final List SUPPORTED_PROVIDERS = Arrays.asList("aws"); @@ -439,7 +439,7 @@ private OidcSaslClient(final MongoCredentialWithCache mongoCredentialWithCache) @Override public byte[] evaluateChallenge(final byte[] challenge) { - return evaluateChallengeInternal(challenge); + return assertNotNull(evaluateChallengeFunction).apply(challenge); } @Override @@ -447,9 +447,6 @@ public boolean isComplete() { return clientIsComplete(); } - public byte[] evaluateChallengeInternal(final byte[] challenge) { - return assertNotNull(evaluateChallengeFunction).apply(challenge); - } } private static String readAwsTokenFromFile() { From 216af3bba1bf630e04d65613dd66aab9418ec4fc Mon Sep 17 00:00:00 2001 From: Maxim Katcharov Date: Fri, 2 Jun 2023 14:45:13 -0600 Subject: [PATCH 18/19] PR fix --- .../main/com/mongodb/internal/connection/OidcAuthenticator.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/driver-core/src/main/com/mongodb/internal/connection/OidcAuthenticator.java b/driver-core/src/main/com/mongodb/internal/connection/OidcAuthenticator.java index 9280510c9d9..ad9793f5dd6 100644 --- a/driver-core/src/main/com/mongodb/internal/connection/OidcAuthenticator.java +++ b/driver-core/src/main/com/mongodb/internal/connection/OidcAuthenticator.java @@ -569,7 +569,7 @@ public static void validateCreateOidcCredential(@Nullable final char[] password) } } - @VisibleForTesting(otherwise = VisibleForTesting.AccessModifier.PACKAGE) + @VisibleForTesting(otherwise = VisibleForTesting.AccessModifier.PRIVATE) public static void validateBeforeUse(final MongoCredential credential) { String userName = credential.getUserName(); Object providerName = credential.getMechanismProperty(PROVIDER_NAME_KEY, null); From 32fc2eb7a59058909e305a1b264d294ec5930989 Mon Sep 17 00:00:00 2001 From: Maxim Katcharov Date: Mon, 5 Jun 2023 09:45:01 -0600 Subject: [PATCH 19/19] Apply suggestions from code review Co-authored-by: Valentin Kovalenko --- .../connection/InternalStreamConnection.java | 13 +++++-------- .../internal/connection/OidcAuthenticator.java | 3 +-- 2 files changed, 6 insertions(+), 10 deletions(-) diff --git a/driver-core/src/main/com/mongodb/internal/connection/InternalStreamConnection.java b/driver-core/src/main/com/mongodb/internal/connection/InternalStreamConnection.java index 894ac0a466e..5a3e2b523ad 100644 --- a/driver-core/src/main/com/mongodb/internal/connection/InternalStreamConnection.java +++ b/driver-core/src/main/com/mongodb/internal/connection/InternalStreamConnection.java @@ -375,19 +375,16 @@ public boolean isClosed() { public T sendAndReceive(final CommandMessage message, final Decoder decoder, final SessionContext sessionContext, final RequestContext requestContext, final OperationContext operationContext) { - if (!Authenticator.shouldAuthenticate(authenticator, this.description)) { - return sendAndReceiveInternal(message, decoder, sessionContext, requestContext, operationContext); - } - Supplier retryableOperation = () -> - sendAndReceiveInternal(message, decoder, sessionContext, requestContext, operationContext); + Supplier sendAndReceiveInternal = () -> sendAndReceiveInternal( + message, decoder, sessionContext, requestContext, operationContext); try { - return retryableOperation.get(); + return sendAndReceiveInternal.get(); } catch (MongoCommandException e) { - if (triggersReauthentication(e)) { + if (triggersReauthentication(e) && Authenticator.shouldAuthenticate(authenticator, this.description)) { authenticated.set(false); authenticator.reauthenticate(this); authenticated.set(true); - return retryableOperation.get(); + return sendAndReceiveInternal.get(); } throw e; } diff --git a/driver-core/src/main/com/mongodb/internal/connection/OidcAuthenticator.java b/driver-core/src/main/com/mongodb/internal/connection/OidcAuthenticator.java index ad9793f5dd6..f3c931a433f 100644 --- a/driver-core/src/main/com/mongodb/internal/connection/OidcAuthenticator.java +++ b/driver-core/src/main/com/mongodb/internal/connection/OidcAuthenticator.java @@ -234,8 +234,7 @@ private void authLock(final InternalConnection connection, final ConnectionDescr authenticateUsing(connection, connectionDescription, (challenge) -> evaluate(challenge)); break; } catch (MongoSecurityException e) { - boolean shouldRetry = triggersRetry(e) && shouldRetryHandler(); - if (!shouldRetry) { + if (!(triggersRetry(e) && shouldRetryHandler())) { throw e; } }