diff --git a/driver-core/src/main/com/mongodb/assertions/Assertions.java b/driver-core/src/main/com/mongodb/assertions/Assertions.java index fff23119f7c..0ea3b4eb2d7 100644 --- a/driver-core/src/main/com/mongodb/assertions/Assertions.java +++ b/driver-core/src/main/com/mongodb/assertions/Assertions.java @@ -17,7 +17,6 @@ package com.mongodb.assertions; -import com.mongodb.internal.async.SingleResultCallback; import com.mongodb.lang.Nullable; import java.util.Collection; @@ -50,25 +49,6 @@ public static T notNull(final String name, final T value) { return value; } - /** - * Throw IllegalArgumentException if the value is null. - * - * @param name the parameter name - * @param value the value that should not be null - * @param callback the callback that also is passed the exception if the value is null - * @param the value type - * @return the value - * @throws java.lang.IllegalArgumentException if value is null - */ - public static T notNull(final String name, final T value, final SingleResultCallback callback) { - if (value == null) { - IllegalArgumentException exception = new IllegalArgumentException(name + " can not be null"); - callback.onResult(null, exception); - throw exception; - } - return value; - } - /** * Throw IllegalStateException if the condition if false. * @@ -82,22 +62,6 @@ public static void isTrue(final String name, final boolean condition) { } } - /** - * Throw IllegalStateException if the condition if false. - * - * @param name the name of the state that is being checked - * @param condition the condition about the parameter to check - * @param callback the callback that also is passed the exception if the condition is not true - * @throws java.lang.IllegalStateException if the condition is false - */ - public static void isTrue(final String name, final boolean condition, final SingleResultCallback callback) { - if (!condition) { - IllegalStateException exception = new IllegalStateException("state should be: " + name); - callback.onResult(null, exception); - throw exception; - } - } - /** * Throw IllegalArgumentException if the condition if false. * diff --git a/driver-core/src/main/com/mongodb/internal/Locks.java b/driver-core/src/main/com/mongodb/internal/Locks.java index c06eddcc6dd..767651e788b 100644 --- a/driver-core/src/main/com/mongodb/internal/Locks.java +++ b/driver-core/src/main/com/mongodb/internal/Locks.java @@ -17,6 +17,8 @@ package com.mongodb.internal; import com.mongodb.MongoInterruptedException; +import com.mongodb.internal.async.AsyncRunnable; +import com.mongodb.internal.async.SingleResultCallback; import java.util.concurrent.locks.Lock; import java.util.concurrent.locks.StampedLock; @@ -33,7 +35,23 @@ public static void withLock(final Lock lock, final Runnable action) { }); } - public static V withLock(final StampedLock lock, final Supplier supplier) { + public static void withLockAsync(final StampedLock lock, final AsyncRunnable runnable, + final SingleResultCallback callback) { + long stamp; + try { + stamp = lock.writeLockInterruptibly(); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + callback.onResult(null, new MongoInterruptedException("Interrupted waiting for lock", e)); + return; + } + + runnable.thenAlwaysRunAndFinish(() -> { + lock.unlockWrite(stamp); + }, callback); + } + + public static void withLock(final StampedLock lock, final Runnable runnable) { long stamp; try { stamp = lock.writeLockInterruptibly(); @@ -42,7 +60,7 @@ public static V withLock(final StampedLock lock, final Supplier supplier) throw new MongoInterruptedException("Interrupted waiting for lock", e); } try { - return supplier.get(); + runnable.run(); } finally { lock.unlockWrite(stamp); } @@ -55,15 +73,15 @@ public static V withLock(final Lock lock, final Supplier supplier) { public static V checkedWithLock(final Lock lock, final CheckedSupplier supplier) throws E { try { lock.lockInterruptibly(); - try { - return supplier.get(); - } finally { - lock.unlock(); - } } catch (InterruptedException e) { Thread.currentThread().interrupt(); throw new MongoInterruptedException("Interrupted waiting for lock", e); } + try { + return supplier.get(); + } finally { + lock.unlock(); + } } private Locks() { diff --git a/driver-core/src/main/com/mongodb/internal/async/AsyncConsumer.java b/driver-core/src/main/com/mongodb/internal/async/AsyncConsumer.java new file mode 100644 index 00000000000..b385670ae88 --- /dev/null +++ b/driver-core/src/main/com/mongodb/internal/async/AsyncConsumer.java @@ -0,0 +1,26 @@ +/* + * Copyright 2008-present MongoDB, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.mongodb.internal.async; + +/** + * See tests for usage (AsyncFunctionsTest). + *

+ * This class is not part of the public API and may be removed or changed at any time + */ +@FunctionalInterface +public interface AsyncConsumer extends AsyncFunction { +} diff --git a/driver-core/src/main/com/mongodb/internal/async/AsyncFunction.java b/driver-core/src/main/com/mongodb/internal/async/AsyncFunction.java new file mode 100644 index 00000000000..8caf176dce6 --- /dev/null +++ b/driver-core/src/main/com/mongodb/internal/async/AsyncFunction.java @@ -0,0 +1,33 @@ +/* + * Copyright 2008-present MongoDB, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.mongodb.internal.async; + +import com.mongodb.lang.Nullable; + +/** + * See tests for usage (AsyncFunctionsTest). + *

+ * This class is not part of the public API and may be removed or changed at any time + */ +@FunctionalInterface +public interface AsyncFunction { + /** + * This should not be called externally, but should be implemented as a + * lambda. To "finish" an async chain, use one of the "finish" methods. + */ + void unsafeFinish(@Nullable T value, SingleResultCallback callback); +} diff --git a/driver-core/src/main/com/mongodb/internal/async/AsyncRunnable.java b/driver-core/src/main/com/mongodb/internal/async/AsyncRunnable.java new file mode 100644 index 00000000000..b9089252f49 --- /dev/null +++ b/driver-core/src/main/com/mongodb/internal/async/AsyncRunnable.java @@ -0,0 +1,158 @@ +/* + * Copyright 2008-present MongoDB, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.mongodb.internal.async; + +import com.mongodb.internal.async.function.RetryState; +import com.mongodb.internal.async.function.RetryingAsyncCallbackSupplier; + +import java.util.function.Predicate; +import java.util.function.Supplier; + +/** + * See tests for usage (AsyncFunctionsTest). + *

+ * This class is not part of the public API and may be removed or changed at any time + */ +@FunctionalInterface +public interface AsyncRunnable extends AsyncSupplier, AsyncConsumer { + + static AsyncRunnable beginAsync() { + return (c) -> c.onResult(null, null); + } + + /** + * Must be invoked at end of async chain + * @param runnable the sync code to invoke (under non-exceptional flow) + * prior to the callback + * @param callback the callback provided by the method the chain is used in + */ + default void thenRunAndFinish(final Runnable runnable, final SingleResultCallback callback) { + this.finish((r, e) -> { + if (e != null) { + callback.onResult(null, e); + return; + } + try { + runnable.run(); + } catch (Throwable t) { + callback.onResult(null, t); + return; + } + callback.onResult(null, null); + }); + } + + /** + * See {@link #thenRunAndFinish(Runnable, SingleResultCallback)}, but the runnable + * will always be executed, including on the exceptional path. + * @param runnable the runnable + * @param callback the callback + */ + default void thenAlwaysRunAndFinish(final Runnable runnable, final SingleResultCallback callback) { + this.finish((r, e) -> { + try { + runnable.run(); + } catch (Throwable t) { + if (e != null) { + t.addSuppressed(e); + } + callback.onResult(null, t); + return; + } + callback.onResult(null, e); + }); + } + + /** + * @param runnable The async runnable to run after this runnable + * @return the composition of this runnable and the runnable, a runnable + */ + default AsyncRunnable thenRun(final AsyncRunnable runnable) { + return (c) -> { + this.unsafeFinish((r, e) -> { + if (e == null) { + runnable.unsafeFinish(c); + } else { + c.onResult(null, e); + } + }); + }; + } + + /** + * @param condition the condition to check + * @param runnable The async runnable to run after this runnable, + * if and only if the condition is met + * @return the composition of this runnable and the runnable, a runnable + */ + default AsyncRunnable thenRunIf(final Supplier condition, final AsyncRunnable runnable) { + return (callback) -> { + this.unsafeFinish((r, e) -> { + if (e != null) { + callback.onResult(null, e); + return; + } + boolean matched; + try { + matched = condition.get(); + } catch (Throwable t) { + callback.onResult(null, t); + return; + } + if (matched) { + runnable.unsafeFinish(callback); + } else { + callback.onResult(null, null); + } + }); + }; + } + + /** + * @param supplier The supplier to supply using after this runnable + * @return the composition of this runnable and the supplier, a supplier + * @param The return type of the resulting supplier + */ + default AsyncSupplier thenSupply(final AsyncSupplier supplier) { + return (c) -> { + this.unsafeFinish((r, e) -> { + if (e == null) { + supplier.unsafeFinish(c); + } else { + c.onResult(null, e); + } + }); + }; + } + + /** + * @param runnable the runnable to loop + * @param shouldRetry condition under which to retry + * @return the composition of this, and the looping branch + * @see RetryingAsyncCallbackSupplier + */ + default AsyncRunnable thenRunRetryingWhile( + final AsyncRunnable runnable, final Predicate shouldRetry) { + return thenRun(callback -> { + new RetryingAsyncCallbackSupplier( + new RetryState(), + (rs, lastAttemptFailure) -> shouldRetry.test(lastAttemptFailure), + cb -> runnable.finish(cb) // finish is required here, to handle exceptions + ).get(callback); + }); + } +} diff --git a/driver-core/src/main/com/mongodb/internal/async/AsyncSupplier.java b/driver-core/src/main/com/mongodb/internal/async/AsyncSupplier.java new file mode 100644 index 00000000000..7b38595bda9 --- /dev/null +++ b/driver-core/src/main/com/mongodb/internal/async/AsyncSupplier.java @@ -0,0 +1,138 @@ +/* + * Copyright 2008-present MongoDB, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.mongodb.internal.async; + +import com.mongodb.lang.Nullable; + +import java.util.function.Predicate; + + +/** + * See tests for usage (AsyncFunctionsTest). + *

+ * This class is not part of the public API and may be removed or changed at any time + */ +@FunctionalInterface +public interface AsyncSupplier extends AsyncFunction { + /** + * This should not be called externally to this API. It should be + * implemented as a lambda. To "finish" an async chain, use one of + * the "finish" methods. + * + * @see #finish(SingleResultCallback) + */ + void unsafeFinish(SingleResultCallback callback); + + /** + * This method must only be used when this AsyncSupplier corresponds + * to a {@link java.util.function.Supplier} (and is therefore being + * used within an async chain method lambda). + * @param callback the callback + */ + default void getAsync(final SingleResultCallback callback) { + unsafeFinish(callback); + } + + @Override + default void unsafeFinish(@Nullable final Void value, final SingleResultCallback callback) { + unsafeFinish(callback); + } + + /** + * Must be invoked at end of async chain. + * @param callback the callback provided by the method the chain is used in + */ + default void finish(final SingleResultCallback callback) { + final boolean[] callbackInvoked = {false}; + try { + this.unsafeFinish((v, e) -> { + callbackInvoked[0] = true; + callback.onResult(v, e); + }); + } catch (Throwable t) { + if (callbackInvoked[0]) { + throw t; + } else { + callback.onResult(null, t); + } + } + } + + /** + * @param function The async function to run after this supplier + * @return the composition of this supplier and the function, a supplier + * @param The return type of the resulting supplier + */ + default AsyncSupplier thenApply(final AsyncFunction function) { + return (c) -> { + this.unsafeFinish((v, e) -> { + if (e == null) { + function.unsafeFinish(v, c); + } else { + c.onResult(null, e); + } + }); + }; + } + + + /** + * @param consumer The async consumer to run after this supplier + * @return the composition of this supplier and the consumer, a runnable + */ + default AsyncRunnable thenConsume(final AsyncConsumer consumer) { + return (c) -> { + this.unsafeFinish((v, e) -> { + if (e == null) { + consumer.unsafeFinish(v, c); + } else { + c.onResult(null, e); + } + }); + }; + } + + /** + * @param errorCheck A check, comparable to a catch-if/otherwise-rethrow + * @param supplier The branch to execute if the error matches + * @return The composition of this, and the conditional branch + */ + default AsyncSupplier onErrorIf( + final Predicate errorCheck, + final AsyncSupplier supplier) { + return (callback) -> this.unsafeFinish((r, e) -> { + if (e == null) { + callback.onResult(r, null); + return; + } + boolean errorMatched; + try { + errorMatched = errorCheck.test(e); + } catch (Throwable t) { + t.addSuppressed(e); + callback.onResult(null, t); + return; + } + if (errorMatched) { + supplier.unsafeFinish(callback); + } else { + callback.onResult(null, e); + } + }); + } + +} diff --git a/driver-core/src/main/com/mongodb/internal/async/function/AsyncCallbackRunnable.java b/driver-core/src/main/com/mongodb/internal/async/function/AsyncCallbackRunnable.java index 7304a9ef9b5..02fdbdf9699 100644 --- a/driver-core/src/main/com/mongodb/internal/async/function/AsyncCallbackRunnable.java +++ b/driver-core/src/main/com/mongodb/internal/async/function/AsyncCallbackRunnable.java @@ -32,17 +32,4 @@ public interface AsyncCallbackRunnable { */ void run(SingleResultCallback callback); - /** - * Converts this {@link AsyncCallbackSupplier} to {@link AsyncCallbackSupplier}{@code }. - */ - default AsyncCallbackSupplier asSupplier() { - return this::run; - } - - /** - * @see AsyncCallbackSupplier#whenComplete(Runnable) - */ - default AsyncCallbackRunnable whenComplete(final Runnable after) { - return callback -> asSupplier().whenComplete(after).get(callback); - } } diff --git a/driver-core/src/main/com/mongodb/internal/async/function/RetryingAsyncCallbackSupplier.java b/driver-core/src/main/com/mongodb/internal/async/function/RetryingAsyncCallbackSupplier.java index 9ebe02f5aa7..92233a072be 100644 --- a/driver-core/src/main/com/mongodb/internal/async/function/RetryingAsyncCallbackSupplier.java +++ b/driver-core/src/main/com/mongodb/internal/async/function/RetryingAsyncCallbackSupplier.java @@ -84,6 +84,13 @@ public RetryingAsyncCallbackSupplier( this.asyncFunction = asyncFunction; } + public RetryingAsyncCallbackSupplier( + final RetryState state, + final BiPredicate retryPredicate, + final AsyncCallbackSupplier asyncFunction) { + this(state, (previouslyChosenFailure, lastAttemptFailure) -> lastAttemptFailure, retryPredicate, asyncFunction); + } + @Override public void get(final SingleResultCallback callback) { /* `asyncFunction` and `callback` are the only externally provided pieces of code for which we do not need to care about diff --git a/driver-core/src/main/com/mongodb/internal/connection/Authenticator.java b/driver-core/src/main/com/mongodb/internal/connection/Authenticator.java index 45e0b078452..232eeb45049 100644 --- a/driver-core/src/main/com/mongodb/internal/connection/Authenticator.java +++ b/driver-core/src/main/com/mongodb/internal/connection/Authenticator.java @@ -27,6 +27,7 @@ import com.mongodb.lang.Nullable; import static com.mongodb.assertions.Assertions.notNull; +import static com.mongodb.internal.async.AsyncRunnable.beginAsync; /** *

This class is not part of the public API and may be removed or changed at any time

@@ -104,4 +105,10 @@ public void reauthenticate(final InternalConnection connection) { authenticate(connection, connection.getDescription()); } + public void reauthenticateAsync(final InternalConnection connection, final SingleResultCallback callback) { + beginAsync().thenRun((c) -> { + authenticateAsync(connection, connection.getDescription(), c); + }).finish(callback); + } + } diff --git a/driver-core/src/main/com/mongodb/internal/connection/InternalConnection.java b/driver-core/src/main/com/mongodb/internal/connection/InternalConnection.java index 59e34404b1f..66ff3e51c16 100644 --- a/driver-core/src/main/com/mongodb/internal/connection/InternalConnection.java +++ b/driver-core/src/main/com/mongodb/internal/connection/InternalConnection.java @@ -50,7 +50,7 @@ public interface InternalConnection extends BufferProvider { ServerDescription getInitialServerDescription(); /** - * Opens the connection so its ready for use + * Opens the connection so its ready for use. Will perform a handshake. */ void open(); diff --git a/driver-core/src/main/com/mongodb/internal/connection/InternalStreamConnection.java b/driver-core/src/main/com/mongodb/internal/connection/InternalStreamConnection.java index 5a3e2b523ad..e9fa002bb38 100644 --- a/driver-core/src/main/com/mongodb/internal/connection/InternalStreamConnection.java +++ b/driver-core/src/main/com/mongodb/internal/connection/InternalStreamConnection.java @@ -44,6 +44,7 @@ import com.mongodb.connection.StreamFactory; import com.mongodb.event.CommandListener; import com.mongodb.internal.VisibleForTesting; +import com.mongodb.internal.async.AsyncSupplier; import com.mongodb.internal.async.SingleResultCallback; import com.mongodb.internal.diagnostics.logging.Logger; import com.mongodb.internal.diagnostics.logging.Loggers; @@ -72,9 +73,12 @@ import java.util.function.Supplier; import static com.mongodb.assertions.Assertions.assertNotNull; +import static com.mongodb.assertions.Assertions.assertNull; import static com.mongodb.assertions.Assertions.isTrue; import static com.mongodb.assertions.Assertions.notNull; +import static com.mongodb.internal.async.AsyncRunnable.beginAsync; import static com.mongodb.internal.async.ErrorHandlingResultCallback.errorHandlingCallback; +import static com.mongodb.internal.connection.Authenticator.shouldAuthenticate; import static com.mongodb.internal.connection.CommandHelper.HELLO; import static com.mongodb.internal.connection.CommandHelper.LEGACY_HELLO; import static com.mongodb.internal.connection.CommandHelper.LEGACY_HELLO_LOWER; @@ -224,7 +228,7 @@ public int getGeneration() { @Override public void open() { - isTrue("Open already called", stream == null); + assertNull(stream); stream = streamFactory.create(getServerAddressWithResolver()); try { stream.open(); @@ -246,7 +250,7 @@ public void open() { @Override public void openAsync(final SingleResultCallback callback) { - isTrue("Open already called", stream == null, callback); + assertNull(stream); try { stream = streamFactory.create(getServerAddressWithResolver()); stream.openAsync(new AsyncCompletionHandler() { @@ -380,17 +384,48 @@ public T sendAndReceive(final CommandMessage message, final Decoder decod try { return sendAndReceiveInternal.get(); } catch (MongoCommandException e) { - if (triggersReauthentication(e) && Authenticator.shouldAuthenticate(authenticator, this.description)) { - authenticated.set(false); - authenticator.reauthenticate(this); - authenticated.set(true); - return sendAndReceiveInternal.get(); + if (reauthenticationIsTriggered(e)) { + return reauthenticateAndRetry(sendAndReceiveInternal); } throw e; } } - public static boolean triggersReauthentication(@Nullable final Throwable t) { + @Override + public void sendAndReceiveAsync(final CommandMessage message, final Decoder decoder, final SessionContext sessionContext, + final RequestContext requestContext, final OperationContext operationContext, final SingleResultCallback callback) { + + AsyncSupplier sendAndReceiveAsyncInternal = c -> sendAndReceiveAsyncInternal( + message, decoder, sessionContext, requestContext, operationContext, c); + beginAsync().thenSupply(c -> { + sendAndReceiveAsyncInternal.getAsync(c); + }).onErrorIf(e -> reauthenticationIsTriggered(e), c -> { + reauthenticateAndRetryAsync(sendAndReceiveAsyncInternal, c); + }).finish(callback); + } + + private T reauthenticateAndRetry(final Supplier operation) { + authenticated.set(false); + assertNotNull(authenticator).reauthenticate(this); + authenticated.set(true); + return operation.get(); + } + + private void reauthenticateAndRetryAsync(final AsyncSupplier operation, + final SingleResultCallback callback) { + beginAsync().thenRun(c -> { + authenticated.set(false); + assertNotNull(authenticator).reauthenticateAsync(this, c); + }).thenSupply((c) -> { + authenticated.set(true); + operation.getAsync(c); + }).finish(callback); + } + + public boolean reauthenticationIsTriggered(@Nullable final Throwable t) { + if (!shouldAuthenticate(authenticator, this.description)) { + return false; + } if (t instanceof MongoCommandException) { MongoCommandException e = (MongoCommandException) t; return e.getErrorCode() == 391; @@ -518,11 +553,8 @@ private T receiveCommandMessageResponse(final Decoder decoder, } } - @Override - public void sendAndReceiveAsync(final CommandMessage message, final Decoder decoder, final SessionContext sessionContext, + private void sendAndReceiveAsyncInternal(final CommandMessage message, final Decoder decoder, final SessionContext sessionContext, final RequestContext requestContext, final OperationContext operationContext, final SingleResultCallback callback) { - notNull("stream is open", stream, callback); - if (isClosed()) { callback.onResult(null, new MongoSocketClosedException("Can not read from a closed socket", getServerAddress())); return; @@ -637,7 +669,7 @@ public void sendMessage(final List byteBuffers, final int lastRequestId @Override public ResponseBuffers receiveMessage(final int responseTo) { - notNull("stream is open", stream); + assertNotNull(stream); if (isClosed()) { throw new MongoSocketClosedException("Cannot read from a closed stream", getServerAddress()); } @@ -655,8 +687,9 @@ private ResponseBuffers receiveMessageWithAdditionalTimeout(final int additional } @Override - public void sendMessageAsync(final List byteBuffers, final int lastRequestId, final SingleResultCallback callback) { - notNull("stream is open", stream, callback); + public void sendMessageAsync(final List byteBuffers, final int lastRequestId, + final SingleResultCallback callback) { + assertNotNull(stream); if (isClosed()) { callback.onResult(null, new MongoSocketClosedException("Can not read from a closed socket", getServerAddress())); @@ -688,7 +721,7 @@ public void failed(final Throwable t) { @Override public void receiveMessageAsync(final int responseTo, final SingleResultCallback callback) { - isTrue("stream is open", stream != null, callback); + assertNotNull(stream); if (isClosed()) { callback.onResult(null, new MongoSocketClosedException("Can not read from a closed socket", getServerAddress())); diff --git a/driver-core/src/main/com/mongodb/internal/connection/OidcAuthenticator.java b/driver-core/src/main/com/mongodb/internal/connection/OidcAuthenticator.java index f3c931a433f..18a54311496 100644 --- a/driver-core/src/main/com/mongodb/internal/connection/OidcAuthenticator.java +++ b/driver-core/src/main/com/mongodb/internal/connection/OidcAuthenticator.java @@ -31,6 +31,7 @@ import com.mongodb.connection.ConnectionDescription; import com.mongodb.internal.Locks; import com.mongodb.internal.Timeout; +import com.mongodb.internal.async.SingleResultCallback; import com.mongodb.internal.VisibleForTesting; import com.mongodb.lang.Nullable; import org.bson.BsonDocument; @@ -66,6 +67,7 @@ import static com.mongodb.assertions.Assertions.assertFalse; import static com.mongodb.assertions.Assertions.assertNotNull; import static com.mongodb.assertions.Assertions.assertTrue; +import static com.mongodb.internal.async.AsyncRunnable.beginAsync; import static com.mongodb.internal.connection.OidcAuthenticator.OidcValidator.validateBeforeUse; import static java.lang.String.format; @@ -180,32 +182,66 @@ private OidcRequestCallback getRequestCallback() { @Override public void reauthenticate(final InternalConnection connection) { - // method must only be called after original handshake: assertTrue(connection.opened()); authLock(connection, connection.getDescription()); } + @Override + public void reauthenticateAsync(final InternalConnection connection, final SingleResultCallback callback) { + beginAsync().thenRun(c -> { + assertTrue(connection.opened()); + authLockAsync(connection, connection.getDescription(), c); + }).finish(callback); + } + @Override public void authenticate(final InternalConnection connection, final ConnectionDescription connectionDescription) { - // method must only be called during original handshake: assertFalse(connection.opened()); - // this method "wraps" the default authentication method in custom OIDC retry logic String accessToken = getValidCachedAccessToken(); if (accessToken != null) { - try { - authenticateUsing(connection, connectionDescription, (challenge) -> prepareTokenAsJwt(accessToken)); - } catch (MongoSecurityException e) { - if (triggersRetry(e)) { - authLock(connection, connectionDescription); - } else { - throw e; - } - } + authenticateOptimistically(connection, connectionDescription, accessToken); } else { authLock(connection, connectionDescription); } } + @Override + void authenticateAsync(final InternalConnection connection, final ConnectionDescription connectionDescription, + final SingleResultCallback callback) { + beginAsync().thenRun(c -> { + assertFalse(connection.opened()); + String accessToken = getValidCachedAccessToken(); + if (accessToken != null) { + authenticateOptimisticallyAsync(connection, connectionDescription, accessToken, c); + } else { + authLockAsync(connection, connectionDescription, c); + } + }).finish(callback); + } + + private void authenticateOptimistically(final InternalConnection connection, + final ConnectionDescription connectionDescription, final String accessToken) { + try { + authenticateUsingFunction(connection, connectionDescription, (challenge) -> prepareTokenAsJwt(accessToken)); + } catch (MongoSecurityException e) { + if (triggersRetry(e)) { + authLock(connection, connectionDescription); + } else { + throw e; + } + } + } + + private void authenticateOptimisticallyAsync(final InternalConnection connection, + final ConnectionDescription connectionDescription, final String accessToken, + final SingleResultCallback callback) { + beginAsync().thenRun(c -> { + authenticateUsingFunctionAsync(connection, connectionDescription, (challenge) -> prepareTokenAsJwt(accessToken), c); + }).onErrorIf(e -> triggersRetry(e), c -> { + authLockAsync(connection, connectionDescription, c); + }).finish(callback); + } + private static boolean triggersRetry(@Nullable final Throwable t) { if (t instanceof MongoSecurityException) { MongoSecurityException e = (MongoSecurityException) t; @@ -218,7 +254,14 @@ private static boolean triggersRetry(@Nullable final Throwable t) { return false; } - private void authenticateUsing( + private void authenticateUsingFunctionAsync(final InternalConnection connection, + final ConnectionDescription connectionDescription, final Function evaluateChallengeFunction, + final SingleResultCallback callback) { + this.evaluateChallengeFunction = evaluateChallengeFunction; + super.authenticateAsync(connection, connectionDescription, callback); + } + + private void authenticateUsingFunction( final InternalConnection connection, final ConnectionDescription connectionDescription, final Function evaluateChallengeFunction) { @@ -226,23 +269,33 @@ private void authenticateUsing( super.authenticate(connection, connectionDescription); } - private void authLock(final InternalConnection connection, final ConnectionDescription connectionDescription) { + private void authLock(final InternalConnection connection, final ConnectionDescription description) { fallbackState = FallbackState.INITIAL; Locks.withLock(getMongoCredentialWithCache().getOidcLock(), () -> { while (true) { try { - authenticateUsing(connection, connectionDescription, (challenge) -> evaluate(challenge)); + authenticateUsingFunction(connection, description, (challenge) -> evaluate(challenge)); break; } catch (MongoSecurityException e) { - if (!(triggersRetry(e) && shouldRetryHandler())) { - throw e; + if (triggersRetry(e) && shouldRetryHandler()) { + continue; } + throw e; } } - return null; }); } + private void authLockAsync(final InternalConnection connection, final ConnectionDescription description, + final SingleResultCallback callback) { + fallbackState = FallbackState.INITIAL; + Locks.withLockAsync(getMongoCredentialWithCache().getOidcLock(), + beginAsync().thenRunRetryingWhile( + c -> authenticateUsingFunctionAsync(connection, description, (challenge) -> evaluate(challenge), c), + e -> triggersRetry(e) && shouldRetryHandler() + ), callback); + } + private byte[] evaluate(final byte[] challenge) { if (isAutomaticAuthentication()) { return prepareAwsTokenFromFileAsJwt(); diff --git a/driver-core/src/main/com/mongodb/internal/connection/SaslAuthenticator.java b/driver-core/src/main/com/mongodb/internal/connection/SaslAuthenticator.java index e399b00bea8..d3dd94ae53f 100644 --- a/driver-core/src/main/com/mongodb/internal/connection/SaslAuthenticator.java +++ b/driver-core/src/main/com/mongodb/internal/connection/SaslAuthenticator.java @@ -128,11 +128,9 @@ private void throwIfSaslClientIsNull(@Nullable final SaslClient saslClient) { } private BsonDocument getNextSaslResponse(final SaslClient saslClient, final InternalConnection connection) { - if (!connection.opened()) { - BsonDocument response = getSpeculativeAuthenticateResponse(); - if (response != null) { - return response; - } + BsonDocument response = connection.opened() ? null : getSpeculativeAuthenticateResponse(); + if (response != null) { + return response; } try { @@ -147,7 +145,7 @@ private void getNextSaslResponseAsync(final SaslClient saslClient, final Interna final SingleResultCallback callback) { SingleResultCallback errHandlingCallback = errorHandlingCallback(callback, LOGGER); try { - BsonDocument response = getSpeculativeAuthenticateResponse(); + BsonDocument response = connection.opened() ? null : getSpeculativeAuthenticateResponse(); if (response == null) { byte[] serverResponse = (saslClient.hasInitialResponse() ? saslClient.evaluateChallenge(new byte[0]) : null); sendSaslStartAsync(serverResponse, connection, (result, t) -> { diff --git a/driver-core/src/test/unit/com/mongodb/internal/async/AsyncFunctionsTest.java b/driver-core/src/test/unit/com/mongodb/internal/async/AsyncFunctionsTest.java new file mode 100644 index 00000000000..375f7b5c555 --- /dev/null +++ b/driver-core/src/test/unit/com/mongodb/internal/async/AsyncFunctionsTest.java @@ -0,0 +1,657 @@ +/* + * Copyright 2008-present MongoDB, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.mongodb.internal.async; + +import com.mongodb.client.TestListener; +import org.junit.jupiter.api.Test; + +import java.util.ArrayList; +import java.util.List; +import java.util.concurrent.atomic.AtomicReference; +import java.util.function.Consumer; +import java.util.function.Supplier; + +import static com.mongodb.internal.async.AsyncRunnable.beginAsync; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.fail; + +final class AsyncFunctionsTest { + private final TestListener listener = new TestListener(); + private final InvocationTracker invocationTracker = new InvocationTracker(); + + @Test + void testVariations1() { + /* + In our async code: + 1. a callback is provided as a method parameter + 2. at least one sync method must be converted to async + + To use this API: + 1. start an async chain using the "beginAsync" static method + 2. use an appropriate chaining method (then...), which will provide "c" + 3. copy all sync code to that method + 4. at the async method, pass in "c" and start a new chaining method + 5. provide the original "callback" at the end of the chain via "finish" + + Async methods MUST be preceded by unaffected "plain" sync code (sync + code with no async counterpart), and this code MUST reside above the + affected method, as it appears in the sync code. Plain code after + the sync method should be supplied via one of the "finally" variants. + Safe "shared" plain code (variable and lambda declarations) which cannot + throw, may remain outside the chained invocations, for convenience. + + Plain sync code MAY throw exceptions, and SHOULD NOT attempt to handle + them asynchronously. The exceptions will be caught and handled by the + chaining methods that contain this sync code. + + Each async lambda MUST invoke its async method with "c", and MUST return + immediately after invoking that method. It MUST NOT, for example, have + a catch or finally (including close on try-with-resources) after the + invocation of the sync method. + + A braced lambda body (with no linebreak before "."), as shown below, + should be used, as this will be consistent with other usages, and allows + the async code to be more easily compared to the sync code. + */ + + // the number of expected variations is often: 1 + N methods invoked + // 1 variation with no exceptions, and N per an exception in each method + assertBehavesSameVariations(2, + () -> { + // single sync method invocations... + sync(1); + }, + (callback) -> { + // ...become a single async invocation, wrapped in begin-thenRun/finish: + beginAsync().thenRun(c -> { + async(1, c); + }).finish(callback); + }); + } + + @Test + void testVariations2() { + // tests pairs + // converting: plain-sync, sync-plain, sync-sync + // (plain-plain does not need an async chain) + + assertBehavesSameVariations(3, + () -> { + // plain (unaffected) invocations... + plain(1); + sync(2); + }, + (callback) -> { + beginAsync().thenRun(c -> { + // ...are preserved above affected methods + plain(1); + async(2, c); + }).finish(callback); + }); + + assertBehavesSameVariations(3, + () -> { + // when a plain invocation follows an affected method... + sync(1); + plain(2); + }, + (callback) -> { + // ...it is moved to its own block + beginAsync().thenRun(c -> { + async(1, c); + }).thenRunAndFinish(() -> { + plain(2); + }, callback); + }); + + assertBehavesSameVariations(3, + () -> { + // when an affected method follows an affected method + sync(1); + sync(2); + }, + (callback) -> { + // ...it is moved to its own block + beginAsync().thenRun(c -> { + async(1, c); + }).thenRun(c -> { + async(2, c); + }).finish(callback); + }); + } + + @Test + void testVariations4() { + // tests the sync-sync pair with preceding and ensuing plain methods: + assertBehavesSameVariations(5, + () -> { + plain(11); + sync(1); + plain(22); + sync(2); + }, + (callback) -> { + beginAsync().thenRun(c -> { + plain(11); + async(1, c); + }).thenRun(c -> { + plain(22); + async(2, c); + }).finish(callback); + }); + + assertBehavesSameVariations(5, + () -> { + sync(1); + plain(11); + sync(2); + plain(22); + }, + (callback) -> { + beginAsync().thenRun(c -> { + async(1, c); + }).thenRun(c -> { + plain(11); + async(2, c); + }).thenRunAndFinish(() ->{ + plain(22); + }, callback); + }); + } + + @Test + void testSupply() { + assertBehavesSameVariations(4, + () -> { + sync(0); + plain(1); + return syncReturns(2); + }, + (callback) -> { + beginAsync().thenRun(c -> { + async(0, c); + }).thenSupply(c -> { + plain(1); + asyncReturns(2, c); + }).finish(callback); + }); + } + + @SuppressWarnings("ConstantConditions") + @Test + void testFullChain() { + // tests a chain: runnable, producer, function, function, consumer + + assertBehavesSameVariations(14, + () -> { + plain(90); + sync(0); + plain(91); + sync(1); + plain(92); + int v = syncReturns(2); + plain(93); + v = syncReturns(v + 1); + plain(94); + v = syncReturns(v + 10); + plain(95); + sync(v + 100); + plain(96); + }, + (callback) -> { + beginAsync().thenRun(c -> { + plain(90); + async(0, c); + }).thenRun(c -> { + plain(91); + async(1, c); + }).thenSupply(c -> { + plain(92); + asyncReturns(2, c); + }).thenApply((v, c) -> { + plain(93); + asyncReturns(v + 1, c); + }).thenApply((v, c) -> { + plain(94); + asyncReturns(v + 10, c); + }).thenConsume((v, c) -> { + plain(95); + async(v + 100, c); + }).thenRunAndFinish(() -> { + plain(96); + }, callback); + }); + } + + @Test + void testVariationsBranching() { + assertBehavesSameVariations(5, + () -> { + if (plainTest(1)) { + sync(2); + } else { + sync(3); + } + }, + (callback) -> { + beginAsync().thenRun(c -> { + if (plainTest(1)) { + async(2, c); + } else { + async(3, c); + } + }).finish(callback); + }); + + // 2 : fail on first sync, fail on test + // 3 : true test, sync2, sync3 + // 2 : false test, sync3 + // 7 total + assertBehavesSameVariations(7, + () -> { + sync(0); + if (plainTest(1)) { + sync(2); + } + sync(3); + }, + (callback) -> { + beginAsync().thenRun(c -> { + async(0, c); + }).thenRunIf(() -> plainTest(1), c -> { + async(2, c); + }).thenRun(c -> { + async(3, c); + }).finish(callback); + }); + + // an additional affected method within the "if" branch + assertBehavesSameVariations(8, + () -> { + sync(0); + if (plainTest(1)) { + sync(21); + sync(22); + } + sync(3); + }, + (callback) -> { + beginAsync().thenRun(c -> { + async(0, c); + }).thenRunIf(() -> plainTest(1), + beginAsync().thenRun(c -> { + async(21, c); + }).thenRun((c) -> { + async(22, c); + }) + ).thenRun(c -> { + async(3, c); + }).finish(callback); + }); + } + + @Test + void testErrorIf() { + // thenSupply: + assertBehavesSameVariations(5, + () -> { + try { + return syncReturns(1); + } catch (Exception e) { + if (e.getMessage().equals(plainTest(1) ? "unexpected" : "exception-1")) { + return syncReturns(2); + } else { + throw e; + } + } + }, + (callback) -> { + beginAsync().thenSupply(c -> { + asyncReturns(1, c); + }).onErrorIf(e -> e.getMessage().equals(plainTest(1) ? "unexpected" : "exception-1"), c -> { + asyncReturns(2, c); + }).finish(callback); + }); + + // thenRun: + assertBehavesSameVariations(5, + () -> { + try { + sync(1); + } catch (Exception e) { + if (e.getMessage().equals(plainTest(1) ? "unexpected" : "exception-1")) { + sync(2); + } else { + throw e; + } + } + }, + (callback) -> { + beginAsync().thenRun(c -> { + async(1, c); + }).onErrorIf(e -> e.getMessage().equals(plainTest(1) ? "unexpected" : "exception-1"), c -> { + async(2, c); + }).finish(callback); + }); + } + + @Test + void testLoop() { + assertBehavesSameVariations(InvocationTracker.DEPTH_LIMIT * 2 + 1, + () -> { + while (true) { + try { + sync(plainTest(0) ? 1 : 2); + break; + } catch (RuntimeException e) { + if (e.getMessage().equals("exception-1")) { + continue; + } + throw e; + } + } + }, + (callback) -> { + beginAsync().thenRunRetryingWhile( + c -> sync(plainTest(0) ? 1 : 2), + e -> e.getMessage().equals("exception-1") + ).finish(callback); + }); + } + + @Test + void testFinally() { + // (in try: normal flow + exception + exception) * (in finally: normal + exception) = 6 + assertBehavesSameVariations(6, + () -> { + try { + plain(1); + sync(2); + } finally { + plain(3); + } + }, + (callback) -> { + beginAsync().thenRun(c -> { + plain(1); + async(2, c); + }).thenAlwaysRunAndFinish(() -> { + plain(3); + }, callback); + }); + } + + @Test + void testUsedAsLambda() { + assertBehavesSameVariations(4, + () -> { + Supplier s = () -> syncReturns(9); + sync(0); + plain(1); + return s.get(); + }, + (callback) -> { + AsyncSupplier s = (c) -> asyncReturns(9, c); + beginAsync().thenRun(c -> { + async(0, c); + }).thenSupply((c) -> { + plain(1); + s.getAsync(c); + }).finish(callback); + }); + } + + @Test + void testVariables() { + assertBehavesSameVariations(3, + () -> { + int something; + something = 90; + sync(something); + something = something + 10; + sync(something); + }, + (callback) -> { + // Certain variables may need to be shared; these can be + // declared (but not initialized) outside the async chain. + // Any container works (atomic allowed but not needed) + final int[] something = new int[1]; + beginAsync().thenRun(c -> { + something[0] = 90; + async(something[0], c); + }).thenRun((c) -> { + something[0] = something[0] + 10; + async(something[0], c); + }).finish(callback); + }); + } + + @Test + void testInvalid() { + assertThrows(IllegalStateException.class, () -> { + beginAsync().thenRun(c -> { + async(3, c); + throw new IllegalStateException("must not cause second callback invocation"); + }).finish((v, e) -> {}); + }); + assertThrows(IllegalStateException.class, () -> { + beginAsync().thenRun(c -> { + async(3, c); + }).finish((v, e) -> { + throw new IllegalStateException("must not cause second callback invocation"); + }); + }); + } + + // invoked methods: + + private void plain(final int i) { + int cur = invocationTracker.getNextOption(2); + if (cur == 0) { + listener.add("plain-exception-" + i); + throw new RuntimeException("affected method exception-" + i); + } else { + listener.add("plain-success-" + i); + } + } + + private boolean plainTest(final int i) { + int cur = invocationTracker.getNextOption(3); + if (cur == 0) { + listener.add("plain-exception-" + i); + throw new RuntimeException("affected method exception-" + i); + } else if (cur == 1) { + listener.add("plain-false-" + i); + return false; + } else { + listener.add("plain-true-" + i); + return true; + } + } + + private void sync(final int i) { + int cur = invocationTracker.getNextOption(2); + if (cur == 0) { + listener.add("affected-exception-" + i); + throw new RuntimeException("exception-" + i); + } else { + listener.add("affected-success-" + i); + } + } + + private Integer syncReturns(final int i) { + int cur = invocationTracker.getNextOption(2); + if (cur == 0) { + listener.add("affected-exception-" + i); + throw new RuntimeException("exception-" + i); + } else { + listener.add("affected-success-" + i); + return i; + } + } + + private void async(final int i, final SingleResultCallback callback) { + try { + sync(i); + callback.onResult(null, null); + } catch (Throwable t) { + callback.onResult(null, t); + } + } + + private void asyncReturns(final int i, final SingleResultCallback callback) { + try { + callback.onResult(syncReturns(i), null); + } catch (Throwable t) { + callback.onResult(null, t); + } + } + + // assert methods: + + private void assertBehavesSameVariations(final int expectedVariations, final Runnable sync, + final Consumer> async) { + assertBehavesSameVariations( + expectedVariations, + () -> { + sync.run(); + return null; + }, + (c) -> { + async.accept((v, e) -> c.onResult(v, e)); + }); + } + + private void assertBehavesSameVariations(final int expectedVariations, final Supplier sync, + final Consumer> async) { + invocationTracker.reset(); + do { + invocationTracker.startInitialStep(); + assertBehavesSame( + sync, + () -> invocationTracker.startMatchStep(), + async); + + } while (invocationTracker.countDown()); + assertEquals(expectedVariations, invocationTracker.getVariationCount()); + } + + private void assertBehavesSame(final Supplier sync, final Runnable between, final Consumer> async) { + T expectedValue = null; + Throwable expectedException = null; + try { + expectedValue = sync.get(); + } catch (Throwable e) { + expectedException = e; + } + List expectedEvents = listener.getEventStrings(); + + listener.clear(); + between.run(); + + AtomicReference actualValue = new AtomicReference<>(); + AtomicReference actualException = new AtomicReference<>(); + try { + async.accept((v, e) -> { + actualValue.set(v); + actualException.set(e); + }); + } catch (Throwable e) { + fail("async threw instead of using callback"); + } + + // The following code can be used to debug variations: + // System.out.println("==="); + // System.out.println(listener.getEventStrings()); + // System.out.println("==="); + + assertEquals(expectedEvents, listener.getEventStrings(), "steps did not match"); + assertEquals(expectedValue, actualValue.get()); + assertEquals(expectedException == null, actualException.get() == null); + if (expectedException != null) { + assertEquals(expectedException.getMessage(), actualException.get().getMessage()); + assertEquals(expectedException.getClass(), actualException.get().getClass()); + } + + listener.clear(); + } + + /** + * Tracks invocations: allows testing of all variations of a method calls + */ + private static class InvocationTracker { + public static final int DEPTH_LIMIT = 50; + private final List invocationResults = new ArrayList<>(); + private boolean isMatchStep = false; // vs initial step + private int item = 0; + private int variationCount = 0; + + public void reset() { + variationCount = 0; + } + + public void startInitialStep() { + variationCount++; + isMatchStep = false; + item = -1; + } + + public int getNextOption(final int myOptionsSize) { + item++; + if (item >= invocationResults.size()) { + if (isMatchStep) { + fail("result should have been pre-initialized: steps may not match"); + } + if (isWithinDepthLimit()) { + invocationResults.add(myOptionsSize - 1); + } else { + invocationResults.add(0); // choose "0" option, usually an exception + } + } + return invocationResults.get(item); + } + + public void startMatchStep() { + isMatchStep = true; + item = -1; + } + + private boolean countDown() { + while (!invocationResults.isEmpty()) { + int lastItemIndex = invocationResults.size() - 1; + int lastItem = invocationResults.get(lastItemIndex); + if (lastItem > 0) { + // count current digit down by 1, until 0 + invocationResults.set(lastItemIndex, lastItem - 1); + return true; + } else { + // current digit completed, remove (move left) + invocationResults.remove(lastItemIndex); + } + } + return false; + } + + public int getVariationCount() { + return variationCount; + } + + public boolean isWithinDepthLimit() { + return invocationResults.size() < DEPTH_LIMIT; + } + } +} diff --git a/driver-reactive-streams/src/test/functional/com/mongodb/internal/connection/OidcAuthenticationAsyncProseTests.java b/driver-reactive-streams/src/test/functional/com/mongodb/internal/connection/OidcAuthenticationAsyncProseTests.java new file mode 100644 index 00000000000..b18825e89a8 --- /dev/null +++ b/driver-reactive-streams/src/test/functional/com/mongodb/internal/connection/OidcAuthenticationAsyncProseTests.java @@ -0,0 +1,70 @@ +/* + * Copyright 2008-present MongoDB, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.mongodb.internal.connection; + +import com.mongodb.MongoClientSettings; +import com.mongodb.client.MongoClient; +import com.mongodb.reactivestreams.client.MongoClients; +import com.mongodb.reactivestreams.client.syncadapter.SyncMongoClient; +import org.junit.jupiter.api.Test; +import reactivestreams.helpers.SubscriberHelpers; + +import java.util.concurrent.TimeUnit; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static util.ThreadTestHelpers.executeAll; + +public class OidcAuthenticationAsyncProseTests extends OidcAuthenticationProseTests { + + @Override + protected MongoClient createMongoClient(final MongoClientSettings settings) { + return new SyncMongoClient(MongoClients.create(settings)); + } + + @Test + public void testNonblockingCallbacks() { + // not a prose spec test + delayNextFind(); + + int simulatedDelayMs = 100; + TestCallback requestCallback = createCallback().setExpired().setDelayMs(simulatedDelayMs); + TestCallback refreshCallback = createCallback().setDelayMs(simulatedDelayMs); + + MongoClientSettings clientSettings = createSettings(OIDC_URL, requestCallback, refreshCallback); + + try (com.mongodb.reactivestreams.client.MongoClient client = MongoClients.create(clientSettings)) { + executeAll(2, () -> { + SubscriberHelpers.OperationSubscriber subscriber = new SubscriberHelpers.OperationSubscriber<>(); + long t1 = System.nanoTime(); + client.getDatabase("test") + .getCollection("test") + .find() + .first() + .subscribe(subscriber); + long elapsedMs = TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - t1); + + assertTrue(elapsedMs < simulatedDelayMs); + subscriber.get(); + }); + + // ensure both callbacks have been tested + assertEquals(1, requestCallback.getInvocations()); + assertEquals(1, refreshCallback.getInvocations()); + } + } +} diff --git a/driver-sync/src/test/functional/com/mongodb/internal/connection/OidcAuthenticationProseTests.java b/driver-sync/src/test/functional/com/mongodb/internal/connection/OidcAuthenticationProseTests.java index 74e95d7e253..368e1342e1f 100644 --- a/driver-sync/src/test/functional/com/mongodb/internal/connection/OidcAuthenticationProseTests.java +++ b/driver-sync/src/test/functional/com/mongodb/internal/connection/OidcAuthenticationProseTests.java @@ -598,12 +598,16 @@ public void test6p1ReauthenticationSucceeds() { assertEquals(0, onRefresh.getInvocations()); assertEquals(Arrays.asList( + // speculative: "isMaster started", "isMaster succeeded", + // onRequest: "onRequest invoked", "read access token: test_user1", + // jwt from onRequest: "saslContinue started", "saslContinue succeeded", + // ensuing find: "find started", "find succeeded" ), listener.getEventStrings()); @@ -624,10 +628,12 @@ public void test6p1ReauthenticationSucceeds() { assertEquals(Arrays.asList( "find started", "find failed", + // find has triggered 391, and cleared the access token; fall back to refresh: "onRefresh invoked", "read access token: test_user1", "saslStart started", "saslStart succeeded", + // find retry succeeds: "find started", "find succeeded" ), listener.getEventStrings());