From 9873b4a779be10140e0875648a58cac4e30ab90b Mon Sep 17 00:00:00 2001 From: Maxim Katcharov Date: Wed, 31 May 2023 11:42:51 -0600 Subject: [PATCH 01/12] Implement OIDC auth for async JAVA-4981 --- .../src/main/com/mongodb/internal/Locks.java | 32 +- .../mongodb/internal/async/AsyncRunnable.java | 165 ++++ .../mongodb/internal/async/AsyncSupplier.java | 60 ++ .../RetryingAsyncCallbackSupplier.java | 10 + .../internal/connection/Authenticator.java | 5 + .../connection/InternalStreamConnection.java | 32 +- .../connection/OidcAuthenticator.java | 54 +- .../connection/SaslAuthenticator.java | 10 +- .../internal/async/AsyncRunnableTest.java | 702 ++++++++++++++++++ .../OidcAuthenticationAsyncProseTests.java | 70 ++ .../OidcAuthenticationProseTests.java | 6 + 11 files changed, 1129 insertions(+), 17 deletions(-) create mode 100644 driver-core/src/main/com/mongodb/internal/async/AsyncRunnable.java create mode 100644 driver-core/src/main/com/mongodb/internal/async/AsyncSupplier.java create mode 100644 driver-core/src/test/unit/com/mongodb/internal/async/AsyncRunnableTest.java create mode 100644 driver-reactive-streams/src/test/functional/com/mongodb/reactivestreams/client/OidcAuthenticationAsyncProseTests.java diff --git a/driver-core/src/main/com/mongodb/internal/Locks.java b/driver-core/src/main/com/mongodb/internal/Locks.java index c06eddcc6dd..a76e544e39f 100644 --- a/driver-core/src/main/com/mongodb/internal/Locks.java +++ b/driver-core/src/main/com/mongodb/internal/Locks.java @@ -17,6 +17,8 @@ package com.mongodb.internal; import com.mongodb.MongoInterruptedException; +import com.mongodb.internal.async.AsyncRunnable; +import com.mongodb.internal.async.SingleResultCallback; import java.util.concurrent.locks.Lock; import java.util.concurrent.locks.StampedLock; @@ -33,6 +35,26 @@ public static void withLock(final Lock lock, final Runnable action) { }); } + public static void withLockAsync(final StampedLock lock, final AsyncRunnable runnable, + final SingleResultCallback callback) { + long stamp; + try { + stamp = lock.writeLockInterruptibly(); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + try { + throw new MongoInterruptedException("Interrupted waiting for lock", e); + } catch (MongoInterruptedException mie) { + callback.onResult(null, mie); + return; + } + } + + runnable.completeAlways(() -> { + lock.unlockWrite(stamp); + }, callback); + } + public static V withLock(final StampedLock lock, final Supplier supplier) { long stamp; try { @@ -55,15 +77,15 @@ public static V withLock(final Lock lock, final Supplier supplier) { public static V checkedWithLock(final Lock lock, final CheckedSupplier supplier) throws E { try { lock.lockInterruptibly(); - try { - return supplier.get(); - } finally { - lock.unlock(); - } } catch (InterruptedException e) { Thread.currentThread().interrupt(); throw new MongoInterruptedException("Interrupted waiting for lock", e); } + try { + return supplier.get(); + } finally { + lock.unlock(); + } } private Locks() { diff --git a/driver-core/src/main/com/mongodb/internal/async/AsyncRunnable.java b/driver-core/src/main/com/mongodb/internal/async/AsyncRunnable.java new file mode 100644 index 00000000000..7a845b8bf96 --- /dev/null +++ b/driver-core/src/main/com/mongodb/internal/async/AsyncRunnable.java @@ -0,0 +1,165 @@ +/* + * Copyright 2008-present MongoDB, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.mongodb.internal.async; + +import com.mongodb.internal.async.function.RetryState; +import com.mongodb.internal.async.function.RetryingAsyncCallbackSupplier; + +import java.util.function.Function; +import java.util.function.Predicate; + +/** + * See AsyncRunnableTest for usage + */ +public interface AsyncRunnable { + + static AsyncRunnable startAsync() { + return (c) -> c.onResult(null, null); + } + + /** + * Must be invoked at end of async chain + * @param callback the callback provided by the method the chain is used in + */ + void complete(SingleResultCallback callback); // NoResultCallback + + /** + * Must be invoked at end of async chain + * @param runnable the sync code to invoke (under non-exceptional flow) + * prior to the callback + * @param callback the callback provided by the method the chain is used in + */ + default void complete(final Runnable runnable, final SingleResultCallback callback) { + this.complete((r, e) -> { + if (e != null) { + callback.onResult(null, e); + return; + } + try { + runnable.run(); + } catch (Throwable t) { + callback.onResult(null, t); + return; + } + callback.onResult(null, null); + }); + } + + /** + * See {@link #complete(Runnable, SingleResultCallback)}, but the runnable + * will always be executed, including on the exceptional path. + * @param runnable the runnable + * @param callback the callback + */ + default void completeAlways(final Runnable runnable, final SingleResultCallback callback) { + this.complete((r, e) -> { + try { + runnable.run(); + } catch (Throwable t) { + callback.onResult(null, t); + return; + } + callback.onResult(r, e); + }); + } + + /** + * @param runnable The async runnable to run after this one + * @return the composition of this and the runnable + */ + default AsyncRunnable run(final AsyncRunnable runnable) { + return (c) -> { + this.complete((r, e) -> { + if (e != null) { + c.onResult(null, e); + return; + } + try { + runnable.complete(c); + } catch (Throwable t) { + c.onResult(null, t); + } + }); + }; + } + + /** + * @param supplier The supplier to supply using after this runnable. + * @return the composition of this runnable and the supplier + * @param The return type of the supplier + */ + default AsyncSupplier supply(final AsyncSupplier supplier) { + return (c) -> { + this.complete((r, e) -> { + if (e != null) { + c.onResult(null, e); + return; + } + try { + supplier.complete(c); + } catch (Throwable t) { + c.onResult(null, t); + } + }); + }; + } + + /** + * @param errorCheck A check, comparable to a catch-if/otherwise-rethrow + * @param runnable The branch to execute if the error matches + * @return The composition of this, and the conditional branch + */ + default AsyncRunnable onErrorIf( + final Function errorCheck, + final AsyncRunnable runnable) { + return (callback) -> this.complete((r, e) -> { + if (e == null) { + callback.onResult(r, null); + return; + } + try { + Boolean check = errorCheck.apply(e); + if (check) { + runnable.complete(callback); + return; + } + } catch (Throwable t) { + callback.onResult(null, t); + return; + } + callback.onResult(r, e); + }); + } + + /** + * @see RetryingAsyncCallbackSupplier + * @param shouldRetry condition under which to retry + * @param runnable the runnable to loop + * @return the composition of this, and the looping branch + */ + default AsyncRunnable runRetryingWhen( + final Predicate shouldRetry, + final AsyncRunnable runnable) { + return this.run(callback -> { + new RetryingAsyncCallbackSupplier( + new RetryState(), + (rs, lastAttemptFailure) -> shouldRetry.test(lastAttemptFailure), + cb -> runnable.complete(cb) + ).get(callback); + }); + } +} diff --git a/driver-core/src/main/com/mongodb/internal/async/AsyncSupplier.java b/driver-core/src/main/com/mongodb/internal/async/AsyncSupplier.java new file mode 100644 index 00000000000..236dfa1905c --- /dev/null +++ b/driver-core/src/main/com/mongodb/internal/async/AsyncSupplier.java @@ -0,0 +1,60 @@ +/* + * Copyright 2008-present MongoDB, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.mongodb.internal.async; + +import java.util.function.Function; + +/** + * See AsyncRunnableTest for usage + */ +public interface AsyncSupplier { + + /** + * Must be invoked at end of async chain + * @param callback the callback provided by the method the chain is used in + */ + void complete(SingleResultCallback callback); + + /** + * @see AsyncRunnable#onErrorIf(Function, AsyncRunnable). + * + * @param errorCheck A check, comparable to a catch-if/otherwise-rethrow + * @param supplier The branch to execute if the error matches + * @return The composition of this, and the conditional branch + */ + default AsyncSupplier onErrorIf( + final Function errorCheck, + final AsyncSupplier supplier) { + return (callback) -> this.complete((r, e) -> { + if (e == null) { + callback.onResult(r, null); + return; + } + try { + Boolean check = errorCheck.apply(e); + if (check) { + supplier.complete(callback); + return; + } + } catch (Throwable t) { + callback.onResult(null, t); + return; + } + callback.onResult(r, e); + }); + } +} diff --git a/driver-core/src/main/com/mongodb/internal/async/function/RetryingAsyncCallbackSupplier.java b/driver-core/src/main/com/mongodb/internal/async/function/RetryingAsyncCallbackSupplier.java index 9ebe02f5aa7..0375efd539c 100644 --- a/driver-core/src/main/com/mongodb/internal/async/function/RetryingAsyncCallbackSupplier.java +++ b/driver-core/src/main/com/mongodb/internal/async/function/RetryingAsyncCallbackSupplier.java @@ -84,6 +84,16 @@ public RetryingAsyncCallbackSupplier( this.asyncFunction = asyncFunction; } + public RetryingAsyncCallbackSupplier( + final RetryState state, + final BiPredicate retryPredicate, + final AsyncCallbackSupplier asyncFunction) { + this.state = state; + this.retryPredicate = retryPredicate; + this.failedResultTransformer = (previouslyChosenFailure, lastAttemptFailure) -> lastAttemptFailure; + this.asyncFunction = asyncFunction; + } + @Override public void get(final SingleResultCallback callback) { /* `asyncFunction` and `callback` are the only externally provided pieces of code for which we do not need to care about diff --git a/driver-core/src/main/com/mongodb/internal/connection/Authenticator.java b/driver-core/src/main/com/mongodb/internal/connection/Authenticator.java index 45e0b078452..0c002dae9fc 100644 --- a/driver-core/src/main/com/mongodb/internal/connection/Authenticator.java +++ b/driver-core/src/main/com/mongodb/internal/connection/Authenticator.java @@ -104,4 +104,9 @@ public void reauthenticate(final InternalConnection connection) { authenticate(connection, connection.getDescription()); } + public void reauthenticateAsync(final InternalConnection connection, final SingleResultCallback callback) { + throw new UnsupportedOperationException( + "Reauthentication requested by server but is not supported by specified mechanism."); + } + } diff --git a/driver-core/src/main/com/mongodb/internal/connection/InternalStreamConnection.java b/driver-core/src/main/com/mongodb/internal/connection/InternalStreamConnection.java index 5a3e2b523ad..c339477c6a2 100644 --- a/driver-core/src/main/com/mongodb/internal/connection/InternalStreamConnection.java +++ b/driver-core/src/main/com/mongodb/internal/connection/InternalStreamConnection.java @@ -44,6 +44,7 @@ import com.mongodb.connection.StreamFactory; import com.mongodb.event.CommandListener; import com.mongodb.internal.VisibleForTesting; +import com.mongodb.internal.async.AsyncSupplier; import com.mongodb.internal.async.SingleResultCallback; import com.mongodb.internal.diagnostics.logging.Logger; import com.mongodb.internal.diagnostics.logging.Loggers; @@ -74,6 +75,7 @@ import static com.mongodb.assertions.Assertions.assertNotNull; import static com.mongodb.assertions.Assertions.isTrue; import static com.mongodb.assertions.Assertions.notNull; +import static com.mongodb.internal.async.AsyncRunnable.startAsync; import static com.mongodb.internal.async.ErrorHandlingResultCallback.errorHandlingCallback; import static com.mongodb.internal.connection.CommandHelper.HELLO; import static com.mongodb.internal.connection.CommandHelper.LEGACY_HELLO; @@ -390,6 +392,31 @@ public T sendAndReceive(final CommandMessage message, final Decoder decod } } + + @Override + public void sendAndReceiveAsync(final CommandMessage message, final Decoder decoder, final SessionContext sessionContext, + final RequestContext requestContext, final OperationContext operationContext, final SingleResultCallback callback) { + notNull("stream is open", stream, callback); + + AsyncSupplier sendAndReceiveAsyncInternal = c -> sendAndReceiveAsyncInternal( + message, decoder, sessionContext, requestContext, operationContext, c); + + if (!Authenticator.shouldAuthenticate(authenticator, this.description)) { + sendAndReceiveAsyncInternal.complete(callback); + return; + } + + sendAndReceiveAsyncInternal.onErrorIf(e -> triggersReauthentication(e), startAsync() + .run(c -> { + authenticated.set(false); + authenticator.reauthenticateAsync(this, c); + }).supply((c) -> { + authenticated.set(true); + sendAndReceiveAsyncInternal.complete(c); + })) + .complete(callback); + } + public static boolean triggersReauthentication(@Nullable final Throwable t) { if (t instanceof MongoCommandException) { MongoCommandException e = (MongoCommandException) t; @@ -518,11 +545,8 @@ private T receiveCommandMessageResponse(final Decoder decoder, } } - @Override - public void sendAndReceiveAsync(final CommandMessage message, final Decoder decoder, final SessionContext sessionContext, + private void sendAndReceiveAsyncInternal(final CommandMessage message, final Decoder decoder, final SessionContext sessionContext, final RequestContext requestContext, final OperationContext operationContext, final SingleResultCallback callback) { - notNull("stream is open", stream, callback); - if (isClosed()) { callback.onResult(null, new MongoSocketClosedException("Can not read from a closed socket", getServerAddress())); return; diff --git a/driver-core/src/main/com/mongodb/internal/connection/OidcAuthenticator.java b/driver-core/src/main/com/mongodb/internal/connection/OidcAuthenticator.java index f3c931a433f..8496070b660 100644 --- a/driver-core/src/main/com/mongodb/internal/connection/OidcAuthenticator.java +++ b/driver-core/src/main/com/mongodb/internal/connection/OidcAuthenticator.java @@ -31,6 +31,7 @@ import com.mongodb.connection.ConnectionDescription; import com.mongodb.internal.Locks; import com.mongodb.internal.Timeout; +import com.mongodb.internal.async.SingleResultCallback; import com.mongodb.internal.VisibleForTesting; import com.mongodb.lang.Nullable; import org.bson.BsonDocument; @@ -66,6 +67,7 @@ import static com.mongodb.assertions.Assertions.assertFalse; import static com.mongodb.assertions.Assertions.assertNotNull; import static com.mongodb.assertions.Assertions.assertTrue; +import static com.mongodb.internal.async.AsyncRunnable.startAsync; import static com.mongodb.internal.connection.OidcAuthenticator.OidcValidator.validateBeforeUse; import static java.lang.String.format; @@ -185,6 +187,15 @@ public void reauthenticate(final InternalConnection connection) { authLock(connection, connection.getDescription()); } + @Override + public void reauthenticateAsync(final InternalConnection connection, final SingleResultCallback callback) { + assertTrue(connection.opened()); + fallbackState = FallbackState.INITIAL; + startAsync().run(c -> { + authLockAsync(connection, connection.getDescription(), c); + }).complete(callback); + } + @Override public void authenticate(final InternalConnection connection, final ConnectionDescription connectionDescription) { // method must only be called during original handshake: @@ -206,6 +217,26 @@ public void authenticate(final InternalConnection connection, final ConnectionDe } } + @Override + void authenticateAsync( + final InternalConnection connection, + final ConnectionDescription connectionDescription, + final SingleResultCallback callback) { + assertFalse(connection.opened()); + String accessToken = getValidCachedAccessToken(); + if (accessToken != null) { + startAsync().run(c -> { + authenticateAsyncUsing(connection, connectionDescription, (bytes) -> prepareTokenAsJwt(accessToken), c); + }).onErrorIf(e -> triggersRetry(e), c -> { + authLockAsync(connection, connectionDescription, c); + }).complete(callback); + } else { + startAsync().run(c -> { + authLockAsync(connection, connectionDescription, c); + }).complete(callback); + } + } + private static boolean triggersRetry(@Nullable final Throwable t) { if (t instanceof MongoSecurityException) { MongoSecurityException e = (MongoSecurityException) t; @@ -218,6 +249,13 @@ private static boolean triggersRetry(@Nullable final Throwable t) { return false; } + private void authenticateAsyncUsing(final InternalConnection connection, + final ConnectionDescription connectionDescription, final Function evaluateChallengeFunction, + final SingleResultCallback callback) { + this.evaluateChallengeFunction = evaluateChallengeFunction; + super.authenticateAsync(connection, connectionDescription, callback); + } + private void authenticateUsing( final InternalConnection connection, final ConnectionDescription connectionDescription, @@ -226,12 +264,12 @@ private void authenticateUsing( super.authenticate(connection, connectionDescription); } - private void authLock(final InternalConnection connection, final ConnectionDescription connectionDescription) { + private void authLock(final InternalConnection connection, final ConnectionDescription description) { fallbackState = FallbackState.INITIAL; Locks.withLock(getMongoCredentialWithCache().getOidcLock(), () -> { while (true) { try { - authenticateUsing(connection, connectionDescription, (challenge) -> evaluate(challenge)); + authenticateUsing(connection, description, (challenge) -> evaluate(challenge)); break; } catch (MongoSecurityException e) { if (!(triggersRetry(e) && shouldRetryHandler())) { @@ -243,6 +281,18 @@ private void authLock(final InternalConnection connection, final ConnectionDescr }); } + private void authLockAsync(final InternalConnection connection, final ConnectionDescription description, + final SingleResultCallback callback) { + + MongoCredentialWithCache mongoCredentialWithCache = getMongoCredentialWithCache(); + Locks.withLockAsync( + mongoCredentialWithCache.getOidcLock(), + startAsync().runRetryingWhen( + e -> triggersRetry(e) && shouldRetryHandler(), + c -> authenticateAsyncUsing(connection, description, (challenge) -> evaluate(challenge), c) + ), callback); + } + private byte[] evaluate(final byte[] challenge) { if (isAutomaticAuthentication()) { return prepareAwsTokenFromFileAsJwt(); diff --git a/driver-core/src/main/com/mongodb/internal/connection/SaslAuthenticator.java b/driver-core/src/main/com/mongodb/internal/connection/SaslAuthenticator.java index e399b00bea8..d3dd94ae53f 100644 --- a/driver-core/src/main/com/mongodb/internal/connection/SaslAuthenticator.java +++ b/driver-core/src/main/com/mongodb/internal/connection/SaslAuthenticator.java @@ -128,11 +128,9 @@ private void throwIfSaslClientIsNull(@Nullable final SaslClient saslClient) { } private BsonDocument getNextSaslResponse(final SaslClient saslClient, final InternalConnection connection) { - if (!connection.opened()) { - BsonDocument response = getSpeculativeAuthenticateResponse(); - if (response != null) { - return response; - } + BsonDocument response = connection.opened() ? null : getSpeculativeAuthenticateResponse(); + if (response != null) { + return response; } try { @@ -147,7 +145,7 @@ private void getNextSaslResponseAsync(final SaslClient saslClient, final Interna final SingleResultCallback callback) { SingleResultCallback errHandlingCallback = errorHandlingCallback(callback, LOGGER); try { - BsonDocument response = getSpeculativeAuthenticateResponse(); + BsonDocument response = connection.opened() ? null : getSpeculativeAuthenticateResponse(); if (response == null) { byte[] serverResponse = (saslClient.hasInitialResponse() ? saslClient.evaluateChallenge(new byte[0]) : null); sendSaslStartAsync(serverResponse, connection, (result, t) -> { diff --git a/driver-core/src/test/unit/com/mongodb/internal/async/AsyncRunnableTest.java b/driver-core/src/test/unit/com/mongodb/internal/async/AsyncRunnableTest.java new file mode 100644 index 00000000000..ea8fadee1aa --- /dev/null +++ b/driver-core/src/test/unit/com/mongodb/internal/async/AsyncRunnableTest.java @@ -0,0 +1,702 @@ +/* + * Copyright 2008-present MongoDB, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.mongodb.internal.async; + +import org.junit.jupiter.api.Test; + +import java.util.concurrent.atomic.AtomicInteger; +import java.util.concurrent.atomic.AtomicReference; +import java.util.function.Consumer; +import java.util.function.Supplier; + +import static com.mongodb.internal.async.AsyncRunnable.startAsync; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertNull; +import static org.junit.jupiter.api.Assertions.fail; + +final class AsyncRunnableTest { + private final AtomicInteger i = new AtomicInteger(); + + @Test + void testRunnableRun() { + /* + In our async code: + 1. a callback is provided + 2. at least one sync method must be converted to async + + To do this: + 1. start an async path using the static method + 2. chain using the appropriate method, which will provide "c" + 3. move all sync code into that method + 4. at the async method, pass in "c" and start a new chained method + 5. complete by invoking the original "callback" at the end of the chain + + Async methods may be preceded by "unaffected" sync code, and this code + will reside above the affected method, as it appears in the sync code. + Below, these "unaffected" methods have no sync/async suffix. + + The return of each chained async method MUST be immediately preceded + by an invocation of the relevant async method using "c". + + Always use a braced lambda body to ensure that the form matches the + corresponding sync code. + */ + assertBehavesSame( + () -> { + multiply(); + incrementSync(); + }, + (callback) -> { + startAsync().run(c -> { + multiply(); + incrementAsync(c); + }).complete(callback); + }); + } + + @Test + void testRunnableRunSyncException() { + // Preceding sync code might throw an exception, so it SHOULD be moved + // into the chain. In any case, any possible exception thrown by it + // MUST be handled by passing it into the callback. + assertBehavesSame( + () -> { + throwException("msg"); + incrementSync(); + }, + (callback) -> { + startAsync().run(c -> { + throwException("msg"); + incrementAsync(c); + }).complete(callback); + }); + + } + + @Test + void testRunnableRunMultiple() { + // Code split across multiple affected methods: + assertBehavesSame( + () -> { + multiply(); + incrementSync(); + multiply(); + incrementSync(); + }, + (callback) -> { + startAsync().run(c -> { + multiply(); + incrementAsync(c); + }).run(c -> { + multiply(); + incrementAsync(c); + }).complete(callback); + }); + } + + @Test + void testRunnableRunMultipleExceptionSkipping() { + // An exception in sync code causes ensuing code to be skipped, and + // split async code behaves in the same way: + assertBehavesSame( + () -> { + throwException("m"); + incrementSync(); + throwException("m2"); + incrementSync(); + }, + (callback) -> { + startAsync().run(c -> { + throwException("m"); + incrementAsync(c); + }).run(c -> { + throwException("m2"); + incrementAsync(c); + }).complete(callback); + }); + } + + @Test + void testRunnableRunMultipleExceptionInAffectedSkipping() { + // Likewise, an exception in the affected method causes a skip: + assertBehavesSame( + () -> { + multiply(); + throwExceptionSync("msg"); + multiply(); + incrementSync(); + }, + (callback) -> { + startAsync().run(c -> { + multiply(); + throwExceptionAsync("msg", c); + }).run(c -> { + multiply(); + incrementAsync(c); + }).complete(callback); + }); + } + + @Test + void testRunnableCompleteRunnable() { + // Sometimes, sync code follows the affected method, and it MUST be + // moved into the final method: + assertBehavesSame( + () -> { + incrementSync(); + multiply(); + }, + (callback) -> { + startAsync().run(c -> { + incrementAsync(c); + }).complete(() -> { + multiply(); + }, callback); + }); + } + + @Test + void testRunnableCompleteRunnableExceptional() { + // ...this makes it easier to correctly handle its exceptions: + assertBehavesSame( + () -> { + incrementSync(); + throwException("m"); + }, + (callback) -> { + startAsync().run(c -> { + incrementAsync(c); + }).complete(() -> { + throwException("m"); + }, callback); + }); + } + + @Test + void testRunnableCompleteRunnableSkippedWhenExceptional() { + // ...and to ensure that it is not executed when it should be skipped: + assertBehavesSame( + () -> { + throwExceptionSync("msg"); + multiply(); + }, + (callback) -> { + startAsync().run(c -> { + throwExceptionAsync("msg", c); + }).complete(() -> { + multiply(); + }, callback); + }); + } + + @Test + void testRunnableCompleteAlways() { + // normal flow + assertBehavesSame( + () -> { + try { + multiply(); + incrementSync(); + } finally { + multiply(); + } + }, + (callback) -> { + startAsync().run(c -> { + multiply(); + incrementAsync(c); + }).completeAlways(() -> { + multiply(); + }, callback); + }); + + } + + @Test + void testRunnableCompleteAlwaysExceptionInAffected() { + // exception in sync/async + assertBehavesSame( + () -> { + try { + multiply(); + throwExceptionSync("msg"); + } finally { + multiply(); + } + }, + (callback) -> { + startAsync().run(c -> { + multiply(); + throwExceptionAsync("msg", c); + }).completeAlways(() -> { + multiply(); + }, callback); + }); + } + + @Test + void testRunnableCompleteAlwaysExceptionInUnaffected() { + // exception in unaffected code + assertBehavesSame( + () -> { + try { + throwException("msg"); + incrementSync(); + } finally { + multiply(); + } + }, + (callback) -> { + startAsync().run(c -> { + throwException("msg"); + incrementAsync(c); + }).completeAlways(() -> { + multiply(); + }, callback); + }); + } + + @Test + void testRunnableCompleteAlwaysExceptionInFinally() { + // exception in finally + assertBehavesSame( + () -> { + try { + multiply(); + incrementSync(); + } finally { + throwException("msg"); + } + }, + (callback) -> { + startAsync().run(c -> { + multiply(); + incrementAsync(c); + }).completeAlways(() -> { + throwException("msg"); + }, callback); + }); + } + + @Test + void testRunnableCompleteAlwaysExceptionInFinallyExceptional() { + // exception in finally, exceptional flow + assertBehavesSame( + () -> { + try { + throwException("first"); + incrementSync(); + } finally { + throwException("msg"); + } + }, + (callback) -> { + startAsync().run(c -> { + throwException("first"); + incrementAsync(c); + }).completeAlways(() -> { + throwException("msg"); + }, callback); + }); + } + + @Test + void testRunnableSupply() { + assertBehavesSame( + () -> { + multiply(); + return valueSync(1); + }, + (callback) -> { + startAsync().supply(c -> { + multiply(); + valueAsync(1, c); + }).complete(callback); + }); + } + + @Test + void testRunnableSupplyExceptional() { + assertBehavesSame( + () -> { + throwException("msg"); + return valueSync(1); + }, + (callback) -> { + startAsync().supply(c -> { + throwException("msg"); + valueAsync(1, c); + }).complete(callback); + }); + } + + @Test + void testRunnableSupplyExceptionalInAffected() { + assertBehavesSame( + () -> { + throwExceptionSync("msg"); + return valueSync(1); + }, + (callback) -> { + startAsync().run(c -> { + throwExceptionAsync("msg", c); + }).supply(c -> { + valueAsync(1, c); + }).complete(callback); + }); + } + + @Test + void testSupplierOnErrorIf() { + // no exception + assertBehavesSame( + () -> { + try { + return valueSync(1); + } catch (Exception e) { + if (e.getMessage().equals("m1")) { + return valueSync(2); + } else { + throw e; + } + } + }, + (SingleResultCallback callback) -> { + startAsync().supply(c -> { + valueAsync(1, c); + }).onErrorIf(e -> e.getMessage().equals("m1"), c -> { + valueAsync(2, c); + }).complete(callback); + }); + } + + @Test + void testSupplierOnErrorIfWithValueBranch() { + // exception, with value branch + assertBehavesSame( + () -> { + try { + return throwExceptionSync("m1"); + } catch (Exception e) { + if (e.getMessage().equals("m1")) { + return valueSync(2); + } else { + throw e; + } + } + }, + (callback) -> { + startAsync().supply(c -> { + throwExceptionAsync("m1", c); + }).onErrorIf(e -> e.getMessage().equals("m1"), c -> { + valueAsync(2, c); + }).complete(callback); + }); + + } + + @Test + void testSupplierOnErrorIfWithExceptionBranch() { + // exception, with exception branch + assertBehavesSame( + () -> { + try { + return throwExceptionSync("m1"); + } catch (Exception e) { + if (e.getMessage().equals("m1")) { + return this.throwExceptionSync("m2"); + } else { + throw e; + } + } + }, + (callback) -> { + startAsync().supply(c -> { + throwExceptionAsync("m1", c); + }).onErrorIf(e -> e.getMessage().equals("m1"), c -> { + throwExceptionAsync("m2", c); + }).complete(callback); + }); + } + + @Test + void testRunnableOnErrorIfNoException() { + // no exception + assertBehavesSame( + () -> { + try { + incrementSync(); + } catch (Exception e) { + if (e.getMessage().equals("m1")) { + multiply(); + incrementSync(); + } else { + throw e; + } + } + }, + (callback) -> { + startAsync().run(c -> { + incrementSync(); + }).onErrorIf(e -> e.getMessage().equals("m1"), c -> { + multiply(); + incrementAsync(c); + }).complete(callback); + }); + + } + + @Test + void testRunnableOnErrorIfThrowsMatching() { + // throws matching exception + assertBehavesSame( + () -> { + try { + throwExceptionSync("m1"); + } catch (Exception e) { + if (e.getMessage().equals("m1")) { + multiply(); + incrementSync(); + } else { + throw e; + } + } + }, + (callback) -> { + startAsync().run(c -> { + throwExceptionAsync("m1", c); + }).onErrorIf(e -> e.getMessage().equals("m1"), c -> { + multiply(); + incrementAsync(c); + }).complete(callback); + }); + + } + + @Test + void testRunnableOnErrorIfThrowsNonMatching() { + // throws non-matching exception + assertBehavesSame( + () -> { + try { + throwExceptionSync("not-m1"); + } catch (Exception e) { + if (e.getMessage().equals("m1")) { + multiply(); + incrementSync(); + } else { + throw e; + } + } + }, + (callback) -> { + startAsync().run(c -> { + throwExceptionAsync("not-m1", c); + }).onErrorIf(e -> e.getMessage().equals("m1"), c -> { + multiply(); + incrementAsync(c); + }).complete(callback); + }); + } + + @Test + void testRunnableOnErrorIfCheckFails() { + // throws but check fails with exception + assertBehavesSame( + () -> { + try { + throwExceptionSync("m1"); + } catch (Exception e) { + if (throwException("check fails")) { + multiply(); + incrementSync(); + } else { + throw e; + } + } + }, + (callback) -> { + startAsync().run(c -> { + throwExceptionAsync("m1", c); + }).onErrorIf(e -> throwException("check fails"), c -> { + multiply(); + incrementAsync(c); + }).complete(callback); + }); + } + + @Test + void testRunnableOnErrorIfSyncBranchfails() { + // throws but sync code in branch fails + assertBehavesSame( + () -> { + try { + throwExceptionSync("m1"); + } catch (Exception e) { + if (e.getMessage().equals("m1")) { + throwException("branch"); + incrementSync(); + } else { + throw e; + } + } + }, + (callback) -> { + startAsync().run(c -> { + throwExceptionAsync("m1", c); + }).onErrorIf(e -> e.getMessage().equals("m1"), c -> { + throwException("branch"); + incrementAsync(c); + }).complete(callback); + }); + } + + @Test + void testRunnableOnErrorIfSyncBranchFailsWithMatching() { + // throws but sync code in branch fails with matching exception + assertBehavesSame( + () -> { + try { + throwExceptionSync("m1"); + } catch (Exception e) { + if (e.getMessage().equals("m1")) { + multiply(); + throwException("m1"); + incrementSync(); + } else { + throw e; + } + } + }, + (callback) -> { + startAsync().run(c -> { + throwExceptionAsync("m1", c); + }).onErrorIf(e -> e.getMessage().equals("m1"), c -> { + multiply(); + throwException("m1"); + incrementAsync(c); + }).complete(callback); + }); + } + + @Test + void testRunnableOnErrorIfThrowsAndBranchedAffectedMethodThrows() { + // throws, and branch sync/async method throws + assertBehavesSame( + () -> { + try { + throwExceptionSync("m1"); + } catch (Exception e) { + if (e.getMessage().equals("m1")) { + multiply(); + throwExceptionSync("m1"); + } else { + throw e; + } + } + }, + (callback) -> { + startAsync().run(c -> { + throwExceptionAsync("m1", c); + }).onErrorIf(e -> e.getMessage().equals("m1"), c -> { + multiply(); + throwExceptionAsync("m1", c); + }).complete(callback); + }); + } + + // unaffected methods: + + private T throwException(final String message) { + throw new RuntimeException(message); + } + + private void multiply() { + i.set(i.get() * 10); + } + + // affected sync-async pairs: + + private void incrementSync() { + i.addAndGet(1); + } + + private void incrementAsync(final SingleResultCallback callback) { + i.addAndGet(1); + callback.onResult(null, null); + } + + private T throwExceptionSync(final String msg) { + throw new RuntimeException(msg); + } + + private void throwExceptionAsync(final String msg, final SingleResultCallback callback) { + try { + throw new RuntimeException(msg); + } catch (Exception e) { + callback.onResult(null, e); + } + } + + private Integer valueSync(final int i) { + return i; + } + + private void valueAsync(final int i, final SingleResultCallback callback) { + callback.onResult(i, null); + } + + private void assertBehavesSame(final Runnable sync, final Consumer> async) { + assertBehavesSame( + () -> { + sync.run(); + return null; + }, + (c) -> { + async.accept((v, e) -> c.onResult(v, e)); + }); + } + + private void assertBehavesSame(final Supplier sync, final Consumer> async) { + AtomicReference actualValue = new AtomicReference<>(); + AtomicReference actualException = new AtomicReference<>(); + try { + i.set(1); + SingleResultCallback callback = (v, e) -> { + actualValue.set(v); + actualException.set(e); + }; + async.accept(callback); + } catch (Exception e) { + fail("async threw an exception instead of using callback"); + } + Integer expectedI = i.get(); + + try { + i.set(1); + T expectedValue = sync.get(); + assertEquals(expectedValue, actualValue.get()); + assertNull(actualException.get()); + } catch (Exception e) { + assertNull(actualValue.get()); + assertNotNull(actualException.get(), "async failed to throw expected: " + e); + assertEquals(e.getClass(), actualException.get().getClass()); + assertEquals(e.getMessage(), actualException.get().getMessage()); + } + assertEquals(expectedI, i.get()); + } +} diff --git a/driver-reactive-streams/src/test/functional/com/mongodb/reactivestreams/client/OidcAuthenticationAsyncProseTests.java b/driver-reactive-streams/src/test/functional/com/mongodb/reactivestreams/client/OidcAuthenticationAsyncProseTests.java new file mode 100644 index 00000000000..dd8eee742dc --- /dev/null +++ b/driver-reactive-streams/src/test/functional/com/mongodb/reactivestreams/client/OidcAuthenticationAsyncProseTests.java @@ -0,0 +1,70 @@ +/* + * Copyright 2008-present MongoDB, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.mongodb.reactivestreams.client; + +import com.mongodb.MongoClientSettings; +import com.mongodb.client.MongoClient; +import com.mongodb.client.OidcAuthenticationProseTests; +import com.mongodb.reactivestreams.client.syncadapter.SyncMongoClient; +import org.junit.jupiter.api.Test; +import reactivestreams.helpers.SubscriberHelpers; + +import java.util.concurrent.TimeUnit; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static util.ThreadTestHelpers.executeAll; + +public class OidcAuthenticationAsyncProseTests extends OidcAuthenticationProseTests { + + @Override + protected MongoClient createMongoClient(final MongoClientSettings settings) { + return new SyncMongoClient(MongoClients.create(settings)); + } + + @Test + public void testNonblockingCallbacks() { + // not a prose spec test + delayNextFind(); + + int simulatedDelayMs = 100; + TestCallback requestCallback = createCallback().setExpired().setDelayMs(simulatedDelayMs); + TestCallback refreshCallback = createCallback().setDelayMs(simulatedDelayMs); + + MongoClientSettings clientSettings = createSettings(OIDC_URL, requestCallback, refreshCallback); + + try (com.mongodb.reactivestreams.client.MongoClient client = MongoClients.create(clientSettings)) { + executeAll(2, () -> { + SubscriberHelpers.OperationSubscriber subscriber = new SubscriberHelpers.OperationSubscriber<>(); + long t1 = System.nanoTime(); + client.getDatabase("test") + .getCollection("test") + .find() + .first() + .subscribe(subscriber); + long elapsedMs = TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - t1); + + assertTrue(elapsedMs < simulatedDelayMs); + subscriber.get(); + }); + + // ensure both callbacks have been tested + assertEquals(1, requestCallback.getInvocations()); + assertEquals(1, refreshCallback.getInvocations()); + } + } +} diff --git a/driver-sync/src/test/functional/com/mongodb/internal/connection/OidcAuthenticationProseTests.java b/driver-sync/src/test/functional/com/mongodb/internal/connection/OidcAuthenticationProseTests.java index 74e95d7e253..368e1342e1f 100644 --- a/driver-sync/src/test/functional/com/mongodb/internal/connection/OidcAuthenticationProseTests.java +++ b/driver-sync/src/test/functional/com/mongodb/internal/connection/OidcAuthenticationProseTests.java @@ -598,12 +598,16 @@ public void test6p1ReauthenticationSucceeds() { assertEquals(0, onRefresh.getInvocations()); assertEquals(Arrays.asList( + // speculative: "isMaster started", "isMaster succeeded", + // onRequest: "onRequest invoked", "read access token: test_user1", + // jwt from onRequest: "saslContinue started", "saslContinue succeeded", + // ensuing find: "find started", "find succeeded" ), listener.getEventStrings()); @@ -624,10 +628,12 @@ public void test6p1ReauthenticationSucceeds() { assertEquals(Arrays.asList( "find started", "find failed", + // find has triggered 391, and cleared the access token; fall back to refresh: "onRefresh invoked", "read access token: test_user1", "saslStart started", "saslStart succeeded", + // find retry succeeds: "find started", "find succeeded" ), listener.getEventStrings()); From 901e9d1fb2a55ee68040e9b7baf984edb992cde7 Mon Sep 17 00:00:00 2001 From: Maxim Katcharov Date: Mon, 5 Jun 2023 10:57:28 -0600 Subject: [PATCH 02/12] Apply suggestions from code review Co-authored-by: Valentin Kovalenko --- driver-core/src/main/com/mongodb/internal/Locks.java | 8 ++------ .../async/function/RetryingAsyncCallbackSupplier.java | 5 +---- .../internal/connection/InternalStreamConnection.java | 8 ++------ .../mongodb/internal/connection/OidcAuthenticator.java | 8 ++------ 4 files changed, 7 insertions(+), 22 deletions(-) diff --git a/driver-core/src/main/com/mongodb/internal/Locks.java b/driver-core/src/main/com/mongodb/internal/Locks.java index a76e544e39f..bfde5e5976e 100644 --- a/driver-core/src/main/com/mongodb/internal/Locks.java +++ b/driver-core/src/main/com/mongodb/internal/Locks.java @@ -42,12 +42,8 @@ public static void withLockAsync(final StampedLock lock, final AsyncRunnable run stamp = lock.writeLockInterruptibly(); } catch (InterruptedException e) { Thread.currentThread().interrupt(); - try { - throw new MongoInterruptedException("Interrupted waiting for lock", e); - } catch (MongoInterruptedException mie) { - callback.onResult(null, mie); - return; - } + callback.onResult(null, new MongoInterruptedException("Interrupted waiting for lock", e)); + return; } runnable.completeAlways(() -> { diff --git a/driver-core/src/main/com/mongodb/internal/async/function/RetryingAsyncCallbackSupplier.java b/driver-core/src/main/com/mongodb/internal/async/function/RetryingAsyncCallbackSupplier.java index 0375efd539c..92233a072be 100644 --- a/driver-core/src/main/com/mongodb/internal/async/function/RetryingAsyncCallbackSupplier.java +++ b/driver-core/src/main/com/mongodb/internal/async/function/RetryingAsyncCallbackSupplier.java @@ -88,10 +88,7 @@ public RetryingAsyncCallbackSupplier( final RetryState state, final BiPredicate retryPredicate, final AsyncCallbackSupplier asyncFunction) { - this.state = state; - this.retryPredicate = retryPredicate; - this.failedResultTransformer = (previouslyChosenFailure, lastAttemptFailure) -> lastAttemptFailure; - this.asyncFunction = asyncFunction; + this(state, (previouslyChosenFailure, lastAttemptFailure) -> lastAttemptFailure, retryPredicate, asyncFunction); } @Override diff --git a/driver-core/src/main/com/mongodb/internal/connection/InternalStreamConnection.java b/driver-core/src/main/com/mongodb/internal/connection/InternalStreamConnection.java index c339477c6a2..6ce6814bb94 100644 --- a/driver-core/src/main/com/mongodb/internal/connection/InternalStreamConnection.java +++ b/driver-core/src/main/com/mongodb/internal/connection/InternalStreamConnection.java @@ -401,12 +401,8 @@ public void sendAndReceiveAsync(final CommandMessage message, final Decoder< AsyncSupplier sendAndReceiveAsyncInternal = c -> sendAndReceiveAsyncInternal( message, decoder, sessionContext, requestContext, operationContext, c); - if (!Authenticator.shouldAuthenticate(authenticator, this.description)) { - sendAndReceiveAsyncInternal.complete(callback); - return; - } - - sendAndReceiveAsyncInternal.onErrorIf(e -> triggersReauthentication(e), startAsync() + sendAndReceiveAsyncInternal.onErrorIf(e -> + triggersReauthentication(e) && Authenticator.shouldAuthenticate(authenticator, this.description) , startAsync() .run(c -> { authenticated.set(false); authenticator.reauthenticateAsync(this, c); diff --git a/driver-core/src/main/com/mongodb/internal/connection/OidcAuthenticator.java b/driver-core/src/main/com/mongodb/internal/connection/OidcAuthenticator.java index 8496070b660..6b951532428 100644 --- a/driver-core/src/main/com/mongodb/internal/connection/OidcAuthenticator.java +++ b/driver-core/src/main/com/mongodb/internal/connection/OidcAuthenticator.java @@ -231,9 +231,7 @@ void authenticateAsync( authLockAsync(connection, connectionDescription, c); }).complete(callback); } else { - startAsync().run(c -> { - authLockAsync(connection, connectionDescription, c); - }).complete(callback); + authLockAsync(connection, connectionDescription, callback); } } @@ -284,9 +282,7 @@ private void authLock(final InternalConnection connection, final ConnectionDescr private void authLockAsync(final InternalConnection connection, final ConnectionDescription description, final SingleResultCallback callback) { - MongoCredentialWithCache mongoCredentialWithCache = getMongoCredentialWithCache(); - Locks.withLockAsync( - mongoCredentialWithCache.getOidcLock(), + Locks.withLockAsync(getMongoCredentialWithCache().getOidcLock(), startAsync().runRetryingWhen( e -> triggersRetry(e) && shouldRetryHandler(), c -> authenticateAsyncUsing(connection, description, (challenge) -> evaluate(challenge), c) From 919af419699e75bb7e172b39f5cf95fc959fc130 Mon Sep 17 00:00:00 2001 From: Maxim Katcharov Date: Tue, 6 Jun 2023 11:56:49 -0600 Subject: [PATCH 03/12] PR Fixes --- .../com/mongodb/assertions/Assertions.java | 36 ----- .../src/main/com/mongodb/internal/Locks.java | 6 +- .../mongodb/internal/async/AsyncRunnable.java | 77 ++++++---- .../mongodb/internal/async/AsyncSupplier.java | 34 +++-- .../connection/InternalConnection.java | 2 +- .../connection/InternalStreamConnection.java | 30 ++-- .../connection/OidcAuthenticator.java | 30 ++-- .../internal/async/AsyncRunnableTest.java | 136 +++++++++--------- .../OidcAuthenticationAsyncProseTests.java | 4 +- 9 files changed, 178 insertions(+), 177 deletions(-) rename driver-reactive-streams/src/test/functional/com/mongodb/{reactivestreams/client => internal/connection}/OidcAuthenticationAsyncProseTests.java (96%) diff --git a/driver-core/src/main/com/mongodb/assertions/Assertions.java b/driver-core/src/main/com/mongodb/assertions/Assertions.java index fff23119f7c..0ea3b4eb2d7 100644 --- a/driver-core/src/main/com/mongodb/assertions/Assertions.java +++ b/driver-core/src/main/com/mongodb/assertions/Assertions.java @@ -17,7 +17,6 @@ package com.mongodb.assertions; -import com.mongodb.internal.async.SingleResultCallback; import com.mongodb.lang.Nullable; import java.util.Collection; @@ -50,25 +49,6 @@ public static T notNull(final String name, final T value) { return value; } - /** - * Throw IllegalArgumentException if the value is null. - * - * @param name the parameter name - * @param value the value that should not be null - * @param callback the callback that also is passed the exception if the value is null - * @param the value type - * @return the value - * @throws java.lang.IllegalArgumentException if value is null - */ - public static T notNull(final String name, final T value, final SingleResultCallback callback) { - if (value == null) { - IllegalArgumentException exception = new IllegalArgumentException(name + " can not be null"); - callback.onResult(null, exception); - throw exception; - } - return value; - } - /** * Throw IllegalStateException if the condition if false. * @@ -82,22 +62,6 @@ public static void isTrue(final String name, final boolean condition) { } } - /** - * Throw IllegalStateException if the condition if false. - * - * @param name the name of the state that is being checked - * @param condition the condition about the parameter to check - * @param callback the callback that also is passed the exception if the condition is not true - * @throws java.lang.IllegalStateException if the condition is false - */ - public static void isTrue(final String name, final boolean condition, final SingleResultCallback callback) { - if (!condition) { - IllegalStateException exception = new IllegalStateException("state should be: " + name); - callback.onResult(null, exception); - throw exception; - } - } - /** * Throw IllegalArgumentException if the condition if false. * diff --git a/driver-core/src/main/com/mongodb/internal/Locks.java b/driver-core/src/main/com/mongodb/internal/Locks.java index bfde5e5976e..767651e788b 100644 --- a/driver-core/src/main/com/mongodb/internal/Locks.java +++ b/driver-core/src/main/com/mongodb/internal/Locks.java @@ -46,12 +46,12 @@ public static void withLockAsync(final StampedLock lock, final AsyncRunnable run return; } - runnable.completeAlways(() -> { + runnable.thenAlwaysRunAndFinish(() -> { lock.unlockWrite(stamp); }, callback); } - public static V withLock(final StampedLock lock, final Supplier supplier) { + public static void withLock(final StampedLock lock, final Runnable runnable) { long stamp; try { stamp = lock.writeLockInterruptibly(); @@ -60,7 +60,7 @@ public static V withLock(final StampedLock lock, final Supplier supplier) throw new MongoInterruptedException("Interrupted waiting for lock", e); } try { - return supplier.get(); + runnable.run(); } finally { lock.unlockWrite(stamp); } diff --git a/driver-core/src/main/com/mongodb/internal/async/AsyncRunnable.java b/driver-core/src/main/com/mongodb/internal/async/AsyncRunnable.java index 7a845b8bf96..f847442a821 100644 --- a/driver-core/src/main/com/mongodb/internal/async/AsyncRunnable.java +++ b/driver-core/src/main/com/mongodb/internal/async/AsyncRunnable.java @@ -19,7 +19,6 @@ import com.mongodb.internal.async.function.RetryState; import com.mongodb.internal.async.function.RetryingAsyncCallbackSupplier; -import java.util.function.Function; import java.util.function.Predicate; /** @@ -27,15 +26,32 @@ */ public interface AsyncRunnable { - static AsyncRunnable startAsync() { + static AsyncRunnable beginAsync() { return (c) -> c.onResult(null, null); } + void runUnsafe(SingleResultCallback callback); // NoResultCallback + /** - * Must be invoked at end of async chain + * Must be invoked at end of async chain. Wraps the lambda in an error + * handler. * @param callback the callback provided by the method the chain is used in */ - void complete(SingleResultCallback callback); // NoResultCallback + default void finish(final SingleResultCallback callback) { + try { + this.runUnsafe((v, e) -> { + try { + callback.onResult(v, e); + } catch (Throwable t) { + throw new CallbackThrew("Unexpected Throwable thrown from callback: ", e); + } + }); + } catch (CallbackThrew t) { + // ignore + } catch (Throwable t) { + callback.onResult(null, t); + } + } /** * Must be invoked at end of async chain @@ -43,8 +59,8 @@ static AsyncRunnable startAsync() { * prior to the callback * @param callback the callback provided by the method the chain is used in */ - default void complete(final Runnable runnable, final SingleResultCallback callback) { - this.complete((r, e) -> { + default void thenRunAndFinish(final Runnable runnable, final SingleResultCallback callback) { + this.finish((r, e) -> { if (e != null) { callback.onResult(null, e); return; @@ -60,13 +76,13 @@ default void complete(final Runnable runnable, final SingleResultCallback } /** - * See {@link #complete(Runnable, SingleResultCallback)}, but the runnable + * See {@link #thenRunAndFinish(Runnable, SingleResultCallback)}, but the runnable * will always be executed, including on the exceptional path. * @param runnable the runnable * @param callback the callback */ - default void completeAlways(final Runnable runnable, final SingleResultCallback callback) { - this.complete((r, e) -> { + default void thenAlwaysRunAndFinish(final Runnable runnable, final SingleResultCallback callback) { + this.finish((r, e) -> { try { runnable.run(); } catch (Throwable t) { @@ -81,15 +97,15 @@ default void completeAlways(final Runnable runnable, final SingleResultCallback< * @param runnable The async runnable to run after this one * @return the composition of this and the runnable */ - default AsyncRunnable run(final AsyncRunnable runnable) { + default AsyncRunnable thenRun(final AsyncRunnable runnable) { return (c) -> { - this.complete((r, e) -> { + this.finish((r, e) -> { if (e != null) { c.onResult(null, e); return; } try { - runnable.complete(c); + runnable.finish(c); } catch (Throwable t) { c.onResult(null, t); } @@ -102,15 +118,15 @@ default AsyncRunnable run(final AsyncRunnable runnable) { * @return the composition of this runnable and the supplier * @param The return type of the supplier */ - default AsyncSupplier supply(final AsyncSupplier supplier) { + default AsyncSupplier thenSupply(final AsyncSupplier supplier) { return (c) -> { - this.complete((r, e) -> { + this.finish((r, e) -> { if (e != null) { c.onResult(null, e); return; } try { - supplier.complete(c); + supplier.finish(c); } catch (Throwable t) { c.onResult(null, t); } @@ -123,18 +139,18 @@ default AsyncSupplier supply(final AsyncSupplier supplier) { * @param runnable The branch to execute if the error matches * @return The composition of this, and the conditional branch */ - default AsyncRunnable onErrorIf( - final Function errorCheck, + default AsyncRunnable onErrorRunIf( + final Predicate errorCheck, final AsyncRunnable runnable) { - return (callback) -> this.complete((r, e) -> { + return (callback) -> this.finish((r, e) -> { if (e == null) { callback.onResult(r, null); return; } try { - Boolean check = errorCheck.apply(e); + boolean check = errorCheck.test(e); if (check) { - runnable.complete(callback); + runnable.finish(callback); return; } } catch (Throwable t) { @@ -146,20 +162,27 @@ default AsyncRunnable onErrorIf( } /** - * @see RetryingAsyncCallbackSupplier + * @param runnable the runnable to loop * @param shouldRetry condition under which to retry - * @param runnable the runnable to loop * @return the composition of this, and the looping branch + * @see RetryingAsyncCallbackSupplier */ - default AsyncRunnable runRetryingWhen( - final Predicate shouldRetry, - final AsyncRunnable runnable) { - return this.run(callback -> { + default AsyncRunnable thenRunRetryingWhile( + final AsyncRunnable runnable, final Predicate shouldRetry) { + return this.thenRun(callback -> { new RetryingAsyncCallbackSupplier( new RetryState(), (rs, lastAttemptFailure) -> shouldRetry.test(lastAttemptFailure), - cb -> runnable.complete(cb) + cb -> runnable.finish(cb) ).get(callback); }); } + + final class CallbackThrew extends AssertionError { + private static final long serialVersionUID = 875624357420415700L; + + public CallbackThrew(final String s, final Throwable e) { + super(s, e); + } + } } diff --git a/driver-core/src/main/com/mongodb/internal/async/AsyncSupplier.java b/driver-core/src/main/com/mongodb/internal/async/AsyncSupplier.java index 236dfa1905c..3b7351723c9 100644 --- a/driver-core/src/main/com/mongodb/internal/async/AsyncSupplier.java +++ b/driver-core/src/main/com/mongodb/internal/async/AsyncSupplier.java @@ -16,38 +16,56 @@ package com.mongodb.internal.async; -import java.util.function.Function; +import java.util.function.Predicate; + +import static com.mongodb.internal.async.AsyncRunnable.CallbackThrew; /** * See AsyncRunnableTest for usage */ public interface AsyncSupplier { + void supplyUnsafe(SingleResultCallback callback); + /** * Must be invoked at end of async chain * @param callback the callback provided by the method the chain is used in */ - void complete(SingleResultCallback callback); + default void finish(final SingleResultCallback callback) { + try { + this.supplyUnsafe((v, e) -> { + try { + callback.onResult(v, e); + } catch (Throwable t) { + throw new CallbackThrew("Unexpected Throwable thrown from callback: ", e); + } + }); + } catch (CallbackThrew t) { + // ignore + } catch (Throwable t) { + callback.onResult(null, t); + } + } /** - * @see AsyncRunnable#onErrorIf(Function, AsyncRunnable). + * @see AsyncRunnable#onErrorRunIf(Predicate, AsyncRunnable). * * @param errorCheck A check, comparable to a catch-if/otherwise-rethrow * @param supplier The branch to execute if the error matches * @return The composition of this, and the conditional branch */ - default AsyncSupplier onErrorIf( - final Function errorCheck, + default AsyncSupplier onErrorSupplyIf( + final Predicate errorCheck, final AsyncSupplier supplier) { - return (callback) -> this.complete((r, e) -> { + return (callback) -> this.finish((r, e) -> { if (e == null) { callback.onResult(r, null); return; } try { - Boolean check = errorCheck.apply(e); + boolean check = errorCheck.test(e); if (check) { - supplier.complete(callback); + supplier.finish(callback); return; } } catch (Throwable t) { diff --git a/driver-core/src/main/com/mongodb/internal/connection/InternalConnection.java b/driver-core/src/main/com/mongodb/internal/connection/InternalConnection.java index 59e34404b1f..66ff3e51c16 100644 --- a/driver-core/src/main/com/mongodb/internal/connection/InternalConnection.java +++ b/driver-core/src/main/com/mongodb/internal/connection/InternalConnection.java @@ -50,7 +50,7 @@ public interface InternalConnection extends BufferProvider { ServerDescription getInitialServerDescription(); /** - * Opens the connection so its ready for use + * Opens the connection so its ready for use. Will perform a handshake. */ void open(); diff --git a/driver-core/src/main/com/mongodb/internal/connection/InternalStreamConnection.java b/driver-core/src/main/com/mongodb/internal/connection/InternalStreamConnection.java index 6ce6814bb94..adca7b688c8 100644 --- a/driver-core/src/main/com/mongodb/internal/connection/InternalStreamConnection.java +++ b/driver-core/src/main/com/mongodb/internal/connection/InternalStreamConnection.java @@ -73,9 +73,10 @@ import java.util.function.Supplier; import static com.mongodb.assertions.Assertions.assertNotNull; +import static com.mongodb.assertions.Assertions.assertNull; import static com.mongodb.assertions.Assertions.isTrue; import static com.mongodb.assertions.Assertions.notNull; -import static com.mongodb.internal.async.AsyncRunnable.startAsync; +import static com.mongodb.internal.async.AsyncRunnable.beginAsync; import static com.mongodb.internal.async.ErrorHandlingResultCallback.errorHandlingCallback; import static com.mongodb.internal.connection.CommandHelper.HELLO; import static com.mongodb.internal.connection.CommandHelper.LEGACY_HELLO; @@ -226,7 +227,7 @@ public int getGeneration() { @Override public void open() { - isTrue("Open already called", stream == null); + assertNull(stream); stream = streamFactory.create(getServerAddressWithResolver()); try { stream.open(); @@ -248,7 +249,7 @@ public void open() { @Override public void openAsync(final SingleResultCallback callback) { - isTrue("Open already called", stream == null, callback); + assertNull(stream); try { stream = streamFactory.create(getServerAddressWithResolver()); stream.openAsync(new AsyncCompletionHandler() { @@ -396,21 +397,19 @@ public T sendAndReceive(final CommandMessage message, final Decoder decod @Override public void sendAndReceiveAsync(final CommandMessage message, final Decoder decoder, final SessionContext sessionContext, final RequestContext requestContext, final OperationContext operationContext, final SingleResultCallback callback) { - notNull("stream is open", stream, callback); AsyncSupplier sendAndReceiveAsyncInternal = c -> sendAndReceiveAsyncInternal( message, decoder, sessionContext, requestContext, operationContext, c); - - sendAndReceiveAsyncInternal.onErrorIf(e -> - triggersReauthentication(e) && Authenticator.shouldAuthenticate(authenticator, this.description) , startAsync() - .run(c -> { + sendAndReceiveAsyncInternal.onErrorSupplyIf(e -> + triggersReauthentication(e) && Authenticator.shouldAuthenticate(authenticator, this.description), beginAsync() + .thenRun(c -> { authenticated.set(false); authenticator.reauthenticateAsync(this, c); - }).supply((c) -> { + }).thenSupply((c) -> { authenticated.set(true); - sendAndReceiveAsyncInternal.complete(c); + sendAndReceiveAsyncInternal.finish(c); })) - .complete(callback); + .finish(callback); } public static boolean triggersReauthentication(@Nullable final Throwable t) { @@ -657,7 +656,7 @@ public void sendMessage(final List byteBuffers, final int lastRequestId @Override public ResponseBuffers receiveMessage(final int responseTo) { - notNull("stream is open", stream); + assertNotNull(stream); if (isClosed()) { throw new MongoSocketClosedException("Cannot read from a closed stream", getServerAddress()); } @@ -675,8 +674,9 @@ private ResponseBuffers receiveMessageWithAdditionalTimeout(final int additional } @Override - public void sendMessageAsync(final List byteBuffers, final int lastRequestId, final SingleResultCallback callback) { - notNull("stream is open", stream, callback); + public void sendMessageAsync(final List byteBuffers, final int lastRequestId, + final SingleResultCallback callback) { + assertNotNull(stream); if (isClosed()) { callback.onResult(null, new MongoSocketClosedException("Can not read from a closed socket", getServerAddress())); @@ -708,7 +708,7 @@ public void failed(final Throwable t) { @Override public void receiveMessageAsync(final int responseTo, final SingleResultCallback callback) { - isTrue("stream is open", stream != null, callback); + assertNotNull(stream); if (isClosed()) { callback.onResult(null, new MongoSocketClosedException("Can not read from a closed socket", getServerAddress())); diff --git a/driver-core/src/main/com/mongodb/internal/connection/OidcAuthenticator.java b/driver-core/src/main/com/mongodb/internal/connection/OidcAuthenticator.java index 6b951532428..b5b29e20f9f 100644 --- a/driver-core/src/main/com/mongodb/internal/connection/OidcAuthenticator.java +++ b/driver-core/src/main/com/mongodb/internal/connection/OidcAuthenticator.java @@ -67,7 +67,7 @@ import static com.mongodb.assertions.Assertions.assertFalse; import static com.mongodb.assertions.Assertions.assertNotNull; import static com.mongodb.assertions.Assertions.assertTrue; -import static com.mongodb.internal.async.AsyncRunnable.startAsync; +import static com.mongodb.internal.async.AsyncRunnable.beginAsync; import static com.mongodb.internal.connection.OidcAuthenticator.OidcValidator.validateBeforeUse; import static java.lang.String.format; @@ -182,7 +182,6 @@ private OidcRequestCallback getRequestCallback() { @Override public void reauthenticate(final InternalConnection connection) { - // method must only be called after original handshake: assertTrue(connection.opened()); authLock(connection, connection.getDescription()); } @@ -190,17 +189,14 @@ public void reauthenticate(final InternalConnection connection) { @Override public void reauthenticateAsync(final InternalConnection connection, final SingleResultCallback callback) { assertTrue(connection.opened()); - fallbackState = FallbackState.INITIAL; - startAsync().run(c -> { + beginAsync().thenRun(c -> { authLockAsync(connection, connection.getDescription(), c); - }).complete(callback); + }).finish(callback); } @Override public void authenticate(final InternalConnection connection, final ConnectionDescription connectionDescription) { - // method must only be called during original handshake: assertFalse(connection.opened()); - // this method "wraps" the default authentication method in custom OIDC retry logic String accessToken = getValidCachedAccessToken(); if (accessToken != null) { try { @@ -225,11 +221,11 @@ void authenticateAsync( assertFalse(connection.opened()); String accessToken = getValidCachedAccessToken(); if (accessToken != null) { - startAsync().run(c -> { + beginAsync().thenRun(c -> { authenticateAsyncUsing(connection, connectionDescription, (bytes) -> prepareTokenAsJwt(accessToken), c); - }).onErrorIf(e -> triggersRetry(e), c -> { + }).onErrorRunIf(e -> triggersRetry(e), c -> { authLockAsync(connection, connectionDescription, c); - }).complete(callback); + }).finish(callback); } else { authLockAsync(connection, connectionDescription, callback); } @@ -270,22 +266,22 @@ private void authLock(final InternalConnection connection, final ConnectionDescr authenticateUsing(connection, description, (challenge) -> evaluate(challenge)); break; } catch (MongoSecurityException e) { - if (!(triggersRetry(e) && shouldRetryHandler())) { - throw e; + if (triggersRetry(e) && shouldRetryHandler()) { + continue; } + throw e; } } - return null; }); } private void authLockAsync(final InternalConnection connection, final ConnectionDescription description, final SingleResultCallback callback) { - + fallbackState = FallbackState.INITIAL; Locks.withLockAsync(getMongoCredentialWithCache().getOidcLock(), - startAsync().runRetryingWhen( - e -> triggersRetry(e) && shouldRetryHandler(), - c -> authenticateAsyncUsing(connection, description, (challenge) -> evaluate(challenge), c) + beginAsync().thenRunRetryingWhile( + c -> authenticateAsyncUsing(connection, description, (challenge) -> evaluate(challenge), c), + e -> triggersRetry(e) && shouldRetryHandler() ), callback); } diff --git a/driver-core/src/test/unit/com/mongodb/internal/async/AsyncRunnableTest.java b/driver-core/src/test/unit/com/mongodb/internal/async/AsyncRunnableTest.java index ea8fadee1aa..96e78d499cc 100644 --- a/driver-core/src/test/unit/com/mongodb/internal/async/AsyncRunnableTest.java +++ b/driver-core/src/test/unit/com/mongodb/internal/async/AsyncRunnableTest.java @@ -22,7 +22,7 @@ import java.util.function.Consumer; import java.util.function.Supplier; -import static com.mongodb.internal.async.AsyncRunnable.startAsync; +import static com.mongodb.internal.async.AsyncRunnable.beginAsync; import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertNotNull; import static org.junit.jupiter.api.Assertions.assertNull; @@ -39,7 +39,7 @@ void testRunnableRun() { 2. at least one sync method must be converted to async To do this: - 1. start an async path using the static method + 1. start an async chain using the static method 2. chain using the appropriate method, which will provide "c" 3. move all sync code into that method 4. at the async method, pass in "c" and start a new chained method @@ -61,10 +61,10 @@ void testRunnableRun() { incrementSync(); }, (callback) -> { - startAsync().run(c -> { + beginAsync().thenRun(c -> { multiply(); incrementAsync(c); - }).complete(callback); + }).finish(callback); }); } @@ -79,10 +79,10 @@ void testRunnableRunSyncException() { incrementSync(); }, (callback) -> { - startAsync().run(c -> { + beginAsync().thenRun(c -> { throwException("msg"); incrementAsync(c); - }).complete(callback); + }).finish(callback); }); } @@ -98,13 +98,13 @@ void testRunnableRunMultiple() { incrementSync(); }, (callback) -> { - startAsync().run(c -> { + beginAsync().thenRun(c -> { multiply(); incrementAsync(c); - }).run(c -> { + }).thenRun(c -> { multiply(); incrementAsync(c); - }).complete(callback); + }).finish(callback); }); } @@ -120,13 +120,13 @@ void testRunnableRunMultipleExceptionSkipping() { incrementSync(); }, (callback) -> { - startAsync().run(c -> { + beginAsync().thenRun(c -> { throwException("m"); incrementAsync(c); - }).run(c -> { + }).thenRun(c -> { throwException("m2"); incrementAsync(c); - }).complete(callback); + }).finish(callback); }); } @@ -141,13 +141,13 @@ void testRunnableRunMultipleExceptionInAffectedSkipping() { incrementSync(); }, (callback) -> { - startAsync().run(c -> { + beginAsync().thenRun(c -> { multiply(); throwExceptionAsync("msg", c); - }).run(c -> { + }).thenRun(c -> { multiply(); incrementAsync(c); - }).complete(callback); + }).finish(callback); }); } @@ -161,9 +161,9 @@ void testRunnableCompleteRunnable() { multiply(); }, (callback) -> { - startAsync().run(c -> { + beginAsync().thenRun(c -> { incrementAsync(c); - }).complete(() -> { + }).thenRunAndFinish(() -> { multiply(); }, callback); }); @@ -178,9 +178,9 @@ void testRunnableCompleteRunnableExceptional() { throwException("m"); }, (callback) -> { - startAsync().run(c -> { + beginAsync().thenRun(c -> { incrementAsync(c); - }).complete(() -> { + }).thenRunAndFinish(() -> { throwException("m"); }, callback); }); @@ -195,9 +195,9 @@ void testRunnableCompleteRunnableSkippedWhenExceptional() { multiply(); }, (callback) -> { - startAsync().run(c -> { + beginAsync().thenRun(c -> { throwExceptionAsync("msg", c); - }).complete(() -> { + }).thenRunAndFinish(() -> { multiply(); }, callback); }); @@ -216,10 +216,10 @@ void testRunnableCompleteAlways() { } }, (callback) -> { - startAsync().run(c -> { + beginAsync().thenRun(c -> { multiply(); incrementAsync(c); - }).completeAlways(() -> { + }).thenAlwaysRunAndFinish(() -> { multiply(); }, callback); }); @@ -239,10 +239,10 @@ void testRunnableCompleteAlwaysExceptionInAffected() { } }, (callback) -> { - startAsync().run(c -> { + beginAsync().thenRun(c -> { multiply(); throwExceptionAsync("msg", c); - }).completeAlways(() -> { + }).thenAlwaysRunAndFinish(() -> { multiply(); }, callback); }); @@ -261,10 +261,10 @@ void testRunnableCompleteAlwaysExceptionInUnaffected() { } }, (callback) -> { - startAsync().run(c -> { + beginAsync().thenRun(c -> { throwException("msg"); incrementAsync(c); - }).completeAlways(() -> { + }).thenAlwaysRunAndFinish(() -> { multiply(); }, callback); }); @@ -283,10 +283,10 @@ void testRunnableCompleteAlwaysExceptionInFinally() { } }, (callback) -> { - startAsync().run(c -> { + beginAsync().thenRun(c -> { multiply(); incrementAsync(c); - }).completeAlways(() -> { + }).thenAlwaysRunAndFinish(() -> { throwException("msg"); }, callback); }); @@ -305,10 +305,10 @@ void testRunnableCompleteAlwaysExceptionInFinallyExceptional() { } }, (callback) -> { - startAsync().run(c -> { + beginAsync().thenRun(c -> { throwException("first"); incrementAsync(c); - }).completeAlways(() -> { + }).thenAlwaysRunAndFinish(() -> { throwException("msg"); }, callback); }); @@ -322,10 +322,10 @@ void testRunnableSupply() { return valueSync(1); }, (callback) -> { - startAsync().supply(c -> { + beginAsync().thenSupply(c -> { multiply(); valueAsync(1, c); - }).complete(callback); + }).finish(callback); }); } @@ -337,10 +337,10 @@ void testRunnableSupplyExceptional() { return valueSync(1); }, (callback) -> { - startAsync().supply(c -> { + beginAsync().thenSupply(c -> { throwException("msg"); valueAsync(1, c); - }).complete(callback); + }).finish(callback); }); } @@ -352,11 +352,11 @@ void testRunnableSupplyExceptionalInAffected() { return valueSync(1); }, (callback) -> { - startAsync().run(c -> { + beginAsync().thenRun(c -> { throwExceptionAsync("msg", c); - }).supply(c -> { + }).thenSupply(c -> { valueAsync(1, c); - }).complete(callback); + }).finish(callback); }); } @@ -376,11 +376,11 @@ void testSupplierOnErrorIf() { } }, (SingleResultCallback callback) -> { - startAsync().supply(c -> { + beginAsync().thenSupply(c -> { valueAsync(1, c); - }).onErrorIf(e -> e.getMessage().equals("m1"), c -> { + }).onErrorSupplyIf(e -> e.getMessage().equals("m1"), c -> { valueAsync(2, c); - }).complete(callback); + }).finish(callback); }); } @@ -400,11 +400,11 @@ void testSupplierOnErrorIfWithValueBranch() { } }, (callback) -> { - startAsync().supply(c -> { + beginAsync().thenSupply(c -> { throwExceptionAsync("m1", c); - }).onErrorIf(e -> e.getMessage().equals("m1"), c -> { + }).onErrorSupplyIf(e -> e.getMessage().equals("m1"), c -> { valueAsync(2, c); - }).complete(callback); + }).finish(callback); }); } @@ -425,11 +425,11 @@ void testSupplierOnErrorIfWithExceptionBranch() { } }, (callback) -> { - startAsync().supply(c -> { + beginAsync().thenSupply(c -> { throwExceptionAsync("m1", c); - }).onErrorIf(e -> e.getMessage().equals("m1"), c -> { + }).onErrorSupplyIf(e -> e.getMessage().equals("m1"), c -> { throwExceptionAsync("m2", c); - }).complete(callback); + }).finish(callback); }); } @@ -450,12 +450,12 @@ void testRunnableOnErrorIfNoException() { } }, (callback) -> { - startAsync().run(c -> { + beginAsync().thenRun(c -> { incrementSync(); - }).onErrorIf(e -> e.getMessage().equals("m1"), c -> { + }).onErrorRunIf(e -> e.getMessage().equals("m1"), c -> { multiply(); incrementAsync(c); - }).complete(callback); + }).finish(callback); }); } @@ -477,12 +477,12 @@ void testRunnableOnErrorIfThrowsMatching() { } }, (callback) -> { - startAsync().run(c -> { + beginAsync().thenRun(c -> { throwExceptionAsync("m1", c); - }).onErrorIf(e -> e.getMessage().equals("m1"), c -> { + }).onErrorRunIf(e -> e.getMessage().equals("m1"), c -> { multiply(); incrementAsync(c); - }).complete(callback); + }).finish(callback); }); } @@ -504,12 +504,12 @@ void testRunnableOnErrorIfThrowsNonMatching() { } }, (callback) -> { - startAsync().run(c -> { + beginAsync().thenRun(c -> { throwExceptionAsync("not-m1", c); - }).onErrorIf(e -> e.getMessage().equals("m1"), c -> { + }).onErrorRunIf(e -> e.getMessage().equals("m1"), c -> { multiply(); incrementAsync(c); - }).complete(callback); + }).finish(callback); }); } @@ -530,12 +530,12 @@ void testRunnableOnErrorIfCheckFails() { } }, (callback) -> { - startAsync().run(c -> { + beginAsync().thenRun(c -> { throwExceptionAsync("m1", c); - }).onErrorIf(e -> throwException("check fails"), c -> { + }).onErrorRunIf(e -> throwException("check fails"), c -> { multiply(); incrementAsync(c); - }).complete(callback); + }).finish(callback); }); } @@ -556,12 +556,12 @@ void testRunnableOnErrorIfSyncBranchfails() { } }, (callback) -> { - startAsync().run(c -> { + beginAsync().thenRun(c -> { throwExceptionAsync("m1", c); - }).onErrorIf(e -> e.getMessage().equals("m1"), c -> { + }).onErrorRunIf(e -> e.getMessage().equals("m1"), c -> { throwException("branch"); incrementAsync(c); - }).complete(callback); + }).finish(callback); }); } @@ -583,13 +583,13 @@ void testRunnableOnErrorIfSyncBranchFailsWithMatching() { } }, (callback) -> { - startAsync().run(c -> { + beginAsync().thenRun(c -> { throwExceptionAsync("m1", c); - }).onErrorIf(e -> e.getMessage().equals("m1"), c -> { + }).onErrorRunIf(e -> e.getMessage().equals("m1"), c -> { multiply(); throwException("m1"); incrementAsync(c); - }).complete(callback); + }).finish(callback); }); } @@ -610,12 +610,12 @@ void testRunnableOnErrorIfThrowsAndBranchedAffectedMethodThrows() { } }, (callback) -> { - startAsync().run(c -> { + beginAsync().thenRun(c -> { throwExceptionAsync("m1", c); - }).onErrorIf(e -> e.getMessage().equals("m1"), c -> { + }).onErrorRunIf(e -> e.getMessage().equals("m1"), c -> { multiply(); throwExceptionAsync("m1", c); - }).complete(callback); + }).finish(callback); }); } diff --git a/driver-reactive-streams/src/test/functional/com/mongodb/reactivestreams/client/OidcAuthenticationAsyncProseTests.java b/driver-reactive-streams/src/test/functional/com/mongodb/internal/connection/OidcAuthenticationAsyncProseTests.java similarity index 96% rename from driver-reactive-streams/src/test/functional/com/mongodb/reactivestreams/client/OidcAuthenticationAsyncProseTests.java rename to driver-reactive-streams/src/test/functional/com/mongodb/internal/connection/OidcAuthenticationAsyncProseTests.java index dd8eee742dc..b18825e89a8 100644 --- a/driver-reactive-streams/src/test/functional/com/mongodb/reactivestreams/client/OidcAuthenticationAsyncProseTests.java +++ b/driver-reactive-streams/src/test/functional/com/mongodb/internal/connection/OidcAuthenticationAsyncProseTests.java @@ -14,11 +14,11 @@ * limitations under the License. */ -package com.mongodb.reactivestreams.client; +package com.mongodb.internal.connection; import com.mongodb.MongoClientSettings; import com.mongodb.client.MongoClient; -import com.mongodb.client.OidcAuthenticationProseTests; +import com.mongodb.reactivestreams.client.MongoClients; import com.mongodb.reactivestreams.client.syncadapter.SyncMongoClient; import org.junit.jupiter.api.Test; import reactivestreams.helpers.SubscriberHelpers; From a81e5aa2f978f1f0454b0fcd204a0f69bbc6f90c Mon Sep 17 00:00:00 2001 From: Maxim Katcharov Date: Mon, 12 Jun 2023 15:07:03 -0600 Subject: [PATCH 04/12] PR fixes --- .../main/com/mongodb/internal/connection/Authenticator.java | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/driver-core/src/main/com/mongodb/internal/connection/Authenticator.java b/driver-core/src/main/com/mongodb/internal/connection/Authenticator.java index 0c002dae9fc..232eeb45049 100644 --- a/driver-core/src/main/com/mongodb/internal/connection/Authenticator.java +++ b/driver-core/src/main/com/mongodb/internal/connection/Authenticator.java @@ -27,6 +27,7 @@ import com.mongodb.lang.Nullable; import static com.mongodb.assertions.Assertions.notNull; +import static com.mongodb.internal.async.AsyncRunnable.beginAsync; /** *

This class is not part of the public API and may be removed or changed at any time

@@ -105,8 +106,9 @@ public void reauthenticate(final InternalConnection connection) { } public void reauthenticateAsync(final InternalConnection connection, final SingleResultCallback callback) { - throw new UnsupportedOperationException( - "Reauthentication requested by server but is not supported by specified mechanism."); + beginAsync().thenRun((c) -> { + authenticateAsync(connection, connection.getDescription(), c); + }).finish(callback); } } From 53b4f67d73670188708c1b74dabf99756f132504 Mon Sep 17 00:00:00 2001 From: Maxim Katcharov Date: Tue, 13 Jun 2023 15:14:01 -0600 Subject: [PATCH 05/12] Apply suggestions from code review Co-authored-by: Valentin Kovalenko --- .../com/mongodb/internal/async/AsyncRunnable.java | 15 +++++---------- .../com/mongodb/internal/async/AsyncSupplier.java | 14 ++++++++------ .../internal/connection/OidcAuthenticator.java | 3 ++- 3 files changed, 15 insertions(+), 17 deletions(-) diff --git a/driver-core/src/main/com/mongodb/internal/async/AsyncRunnable.java b/driver-core/src/main/com/mongodb/internal/async/AsyncRunnable.java index f847442a821..db2f4dd4c02 100644 --- a/driver-core/src/main/com/mongodb/internal/async/AsyncRunnable.java +++ b/driver-core/src/main/com/mongodb/internal/async/AsyncRunnable.java @@ -86,6 +86,9 @@ default void thenAlwaysRunAndFinish(final Runnable runnable, final SingleResultC try { runnable.run(); } catch (Throwable t) { + if (e != null) { + t.addSuppressed(e); + } callback.onResult(null, t); return; } @@ -104,11 +107,7 @@ default AsyncRunnable thenRun(final AsyncRunnable runnable) { c.onResult(null, e); return; } - try { - runnable.finish(c); - } catch (Throwable t) { - c.onResult(null, t); - } + runnable.finish(c); }); }; } @@ -125,11 +124,7 @@ default AsyncSupplier thenSupply(final AsyncSupplier supplier) { c.onResult(null, e); return; } - try { - supplier.finish(c); - } catch (Throwable t) { - c.onResult(null, t); - } + supplier.finish(c); }); }; } diff --git a/driver-core/src/main/com/mongodb/internal/async/AsyncSupplier.java b/driver-core/src/main/com/mongodb/internal/async/AsyncSupplier.java index 3b7351723c9..01b89d97da2 100644 --- a/driver-core/src/main/com/mongodb/internal/async/AsyncSupplier.java +++ b/driver-core/src/main/com/mongodb/internal/async/AsyncSupplier.java @@ -62,17 +62,19 @@ default AsyncSupplier onErrorSupplyIf( callback.onResult(r, null); return; } + boolean useProvidedSupplier; try { - boolean check = errorCheck.test(e); - if (check) { - supplier.finish(callback); - return; - } + useProvidedSupplier = errorCheck.test(e); } catch (Throwable t) { + t.addSuppressed(e); callback.onResult(null, t); return; } - callback.onResult(r, e); + if (useProvidedSupplier) { + supplier.finish(callback); + return; + } + callback.onResult(null, e); }); } } diff --git a/driver-core/src/main/com/mongodb/internal/connection/OidcAuthenticator.java b/driver-core/src/main/com/mongodb/internal/connection/OidcAuthenticator.java index b5b29e20f9f..78755f1b5e3 100644 --- a/driver-core/src/main/com/mongodb/internal/connection/OidcAuthenticator.java +++ b/driver-core/src/main/com/mongodb/internal/connection/OidcAuthenticator.java @@ -190,6 +190,7 @@ public void reauthenticate(final InternalConnection connection) { public void reauthenticateAsync(final InternalConnection connection, final SingleResultCallback callback) { assertTrue(connection.opened()); beginAsync().thenRun(c -> { + assertTrue(connection.opened()); authLockAsync(connection, connection.getDescription(), c); }).finish(callback); } @@ -222,7 +223,7 @@ void authenticateAsync( String accessToken = getValidCachedAccessToken(); if (accessToken != null) { beginAsync().thenRun(c -> { - authenticateAsyncUsing(connection, connectionDescription, (bytes) -> prepareTokenAsJwt(accessToken), c); + authenticateAsyncUsing(connection, connectionDescription, (challenge) -> prepareTokenAsJwt(accessToken), c); }).onErrorRunIf(e -> triggersRetry(e), c -> { authLockAsync(connection, connectionDescription, c); }).finish(callback); From 6cc18af7b4941a0892c7d4af8418b9a664c1d54b Mon Sep 17 00:00:00 2001 From: Maxim Katcharov Date: Tue, 13 Jun 2023 15:14:18 -0600 Subject: [PATCH 06/12] Fixes --- .../src/main/com/mongodb/internal/async/AsyncRunnable.java | 2 +- .../src/main/com/mongodb/internal/async/AsyncSupplier.java | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/driver-core/src/main/com/mongodb/internal/async/AsyncRunnable.java b/driver-core/src/main/com/mongodb/internal/async/AsyncRunnable.java index db2f4dd4c02..19e664b8c8d 100644 --- a/driver-core/src/main/com/mongodb/internal/async/AsyncRunnable.java +++ b/driver-core/src/main/com/mongodb/internal/async/AsyncRunnable.java @@ -43,7 +43,7 @@ default void finish(final SingleResultCallback callback) { try { callback.onResult(v, e); } catch (Throwable t) { - throw new CallbackThrew("Unexpected Throwable thrown from callback: ", e); + throw new CallbackThrew("Unexpected Throwable thrown from callback: ", t); } }); } catch (CallbackThrew t) { diff --git a/driver-core/src/main/com/mongodb/internal/async/AsyncSupplier.java b/driver-core/src/main/com/mongodb/internal/async/AsyncSupplier.java index 01b89d97da2..b2d694e3131 100644 --- a/driver-core/src/main/com/mongodb/internal/async/AsyncSupplier.java +++ b/driver-core/src/main/com/mongodb/internal/async/AsyncSupplier.java @@ -37,7 +37,7 @@ default void finish(final SingleResultCallback callback) { try { callback.onResult(v, e); } catch (Throwable t) { - throw new CallbackThrew("Unexpected Throwable thrown from callback: ", e); + throw new CallbackThrew("Unexpected Throwable thrown from callback: ", t); } }); } catch (CallbackThrew t) { From 2ec4729c910920fd50cd7e3fe0c3c65c5c533a55 Mon Sep 17 00:00:00 2001 From: Maxim Katcharov Date: Thu, 15 Jun 2023 12:27:15 -0600 Subject: [PATCH 07/12] PR fixes --- .../mongodb/internal/async/AsyncRunnable.java | 58 +++++++++---------- .../mongodb/internal/async/AsyncSupplier.java | 28 +++++---- .../internal/async/AsyncRunnableTest.java | 6 +- 3 files changed, 42 insertions(+), 50 deletions(-) diff --git a/driver-core/src/main/com/mongodb/internal/async/AsyncRunnable.java b/driver-core/src/main/com/mongodb/internal/async/AsyncRunnable.java index 19e664b8c8d..8e5a35f3dcf 100644 --- a/driver-core/src/main/com/mongodb/internal/async/AsyncRunnable.java +++ b/driver-core/src/main/com/mongodb/internal/async/AsyncRunnable.java @@ -30,7 +30,7 @@ static AsyncRunnable beginAsync() { return (c) -> c.onResult(null, null); } - void runUnsafe(SingleResultCallback callback); // NoResultCallback + void runInternal(SingleResultCallback callback); // NoResultCallback /** * Must be invoked at end of async chain. Wraps the lambda in an error @@ -38,18 +38,18 @@ static AsyncRunnable beginAsync() { * @param callback the callback provided by the method the chain is used in */ default void finish(final SingleResultCallback callback) { + final boolean[] callbackInvoked = {false}; try { - this.runUnsafe((v, e) -> { - try { - callback.onResult(v, e); - } catch (Throwable t) { - throw new CallbackThrew("Unexpected Throwable thrown from callback: ", t); - } + this.runInternal((v, e) -> { + callbackInvoked[0] = true; + callback.onResult(v, e); }); - } catch (CallbackThrew t) { - // ignore } catch (Throwable t) { - callback.onResult(null, t); + if (callbackInvoked[0]) { + throw t; + } else { + callback.onResult(null, t); + } } } @@ -102,12 +102,12 @@ default void thenAlwaysRunAndFinish(final Runnable runnable, final SingleResultC */ default AsyncRunnable thenRun(final AsyncRunnable runnable) { return (c) -> { - this.finish((r, e) -> { - if (e != null) { + this.runInternal((r, e) -> { + if (e == null) { + runnable.runInternal(c); + } else { c.onResult(null, e); - return; } - runnable.finish(c); }); }; } @@ -119,12 +119,12 @@ default AsyncRunnable thenRun(final AsyncRunnable runnable) { */ default AsyncSupplier thenSupply(final AsyncSupplier supplier) { return (c) -> { - this.finish((r, e) -> { - if (e != null) { + this.runInternal((r, e) -> { + if (e == null) { + supplier.supplyInternal(c); + } else { c.onResult(null, e); - return; } - supplier.finish(c); }); }; } @@ -137,21 +137,23 @@ default AsyncSupplier thenSupply(final AsyncSupplier supplier) { default AsyncRunnable onErrorRunIf( final Predicate errorCheck, final AsyncRunnable runnable) { - return (callback) -> this.finish((r, e) -> { + return (callback) -> this.runInternal((r, e) -> { if (e == null) { callback.onResult(r, null); return; } + boolean errorMatched; try { - boolean check = errorCheck.test(e); - if (check) { - runnable.finish(callback); - return; - } + errorMatched = errorCheck.test(e); } catch (Throwable t) { + t.addSuppressed(e); callback.onResult(null, t); return; } + if (errorMatched) { + runnable.runInternal(callback); + return; + } callback.onResult(r, e); }); } @@ -172,12 +174,4 @@ default AsyncRunnable thenRunRetryingWhile( ).get(callback); }); } - - final class CallbackThrew extends AssertionError { - private static final long serialVersionUID = 875624357420415700L; - - public CallbackThrew(final String s, final Throwable e) { - super(s, e); - } - } } diff --git a/driver-core/src/main/com/mongodb/internal/async/AsyncSupplier.java b/driver-core/src/main/com/mongodb/internal/async/AsyncSupplier.java index b2d694e3131..1f2e2ad8152 100644 --- a/driver-core/src/main/com/mongodb/internal/async/AsyncSupplier.java +++ b/driver-core/src/main/com/mongodb/internal/async/AsyncSupplier.java @@ -18,32 +18,30 @@ import java.util.function.Predicate; -import static com.mongodb.internal.async.AsyncRunnable.CallbackThrew; - /** * See AsyncRunnableTest for usage */ public interface AsyncSupplier { - void supplyUnsafe(SingleResultCallback callback); + void supplyInternal(SingleResultCallback callback); /** * Must be invoked at end of async chain * @param callback the callback provided by the method the chain is used in */ default void finish(final SingleResultCallback callback) { + final boolean[] callbackInvoked = {false}; try { - this.supplyUnsafe((v, e) -> { - try { - callback.onResult(v, e); - } catch (Throwable t) { - throw new CallbackThrew("Unexpected Throwable thrown from callback: ", t); - } + this.supplyInternal((v, e) -> { + callbackInvoked[0] = true; + callback.onResult(v, e); }); - } catch (CallbackThrew t) { - // ignore } catch (Throwable t) { - callback.onResult(null, t); + if (callbackInvoked[0]) { + throw t; + } else { + callback.onResult(null, t); + } } } @@ -62,15 +60,15 @@ default AsyncSupplier onErrorSupplyIf( callback.onResult(r, null); return; } - boolean useProvidedSupplier; + boolean errorMatched; try { - useProvidedSupplier = errorCheck.test(e); + errorMatched = errorCheck.test(e); } catch (Throwable t) { t.addSuppressed(e); callback.onResult(null, t); return; } - if (useProvidedSupplier) { + if (errorMatched) { supplier.finish(callback); return; } diff --git a/driver-core/src/test/unit/com/mongodb/internal/async/AsyncRunnableTest.java b/driver-core/src/test/unit/com/mongodb/internal/async/AsyncRunnableTest.java index 96e78d499cc..8cefe4de6cb 100644 --- a/driver-core/src/test/unit/com/mongodb/internal/async/AsyncRunnableTest.java +++ b/driver-core/src/test/unit/com/mongodb/internal/async/AsyncRunnableTest.java @@ -681,8 +681,8 @@ private void assertBehavesSame(final Supplier sync, final Consumer void assertBehavesSame(final Supplier sync, final Consumer Date: Fri, 16 Jun 2023 11:06:44 -0600 Subject: [PATCH 08/12] Add full test variations, additional chaining methods, refactor --- .../org/bson/codecs/kotlin/DataClassCodec.kt | 4 +- .../codecs/kotlin/DataClassCodecProvider.kt | 5 +- .../kotlin/DataClassCodecProviderTest.kt | 8 +- .../mongodb/internal/async/AsyncConsumer.java | 23 + .../mongodb/internal/async/AsyncFunction.java | 26 + .../mongodb/internal/async/AsyncRunnable.java | 108 ++- .../mongodb/internal/async/AsyncSupplier.java | 53 +- .../connection/InternalStreamConnection.java | 7 +- .../connection/OidcAuthenticator.java | 2 +- .../internal/async/AsyncFunctionsTest.java | 586 +++++++++++++++ .../internal/async/AsyncRunnableTest.java | 702 ------------------ 11 files changed, 734 insertions(+), 790 deletions(-) create mode 100644 driver-core/src/main/com/mongodb/internal/async/AsyncConsumer.java create mode 100644 driver-core/src/main/com/mongodb/internal/async/AsyncFunction.java create mode 100644 driver-core/src/test/unit/com/mongodb/internal/async/AsyncFunctionsTest.java delete mode 100644 driver-core/src/test/unit/com/mongodb/internal/async/AsyncRunnableTest.java diff --git a/bson-kotlin/src/main/kotlin/org/bson/codecs/kotlin/DataClassCodec.kt b/bson-kotlin/src/main/kotlin/org/bson/codecs/kotlin/DataClassCodec.kt index da84033c521..2b67e1de0c3 100644 --- a/bson-kotlin/src/main/kotlin/org/bson/codecs/kotlin/DataClassCodec.kt +++ b/bson-kotlin/src/main/kotlin/org/bson/codecs/kotlin/DataClassCodec.kt @@ -139,8 +139,8 @@ internal data class DataClassCodec( validateAnnotations(kClass) val primaryConstructor = kClass.primaryConstructor ?: throw CodecConfigurationException("No primary constructor for $kClass") - val typeMap = types.mapIndexed { i, k -> primaryConstructor.typeParameters[i].createType() to k } - .toMap() + val typeMap = + types.mapIndexed { i, k -> primaryConstructor.typeParameters[i].createType() to k }.toMap() val propertyModels = primaryConstructor.parameters.map { kParameter -> diff --git a/bson-kotlin/src/main/kotlin/org/bson/codecs/kotlin/DataClassCodecProvider.kt b/bson-kotlin/src/main/kotlin/org/bson/codecs/kotlin/DataClassCodecProvider.kt index e13d74116c7..962741033e1 100644 --- a/bson-kotlin/src/main/kotlin/org/bson/codecs/kotlin/DataClassCodecProvider.kt +++ b/bson-kotlin/src/main/kotlin/org/bson/codecs/kotlin/DataClassCodecProvider.kt @@ -15,15 +15,14 @@ */ package org.bson.codecs.kotlin +import java.lang.reflect.Type import org.bson.codecs.Codec import org.bson.codecs.configuration.CodecProvider import org.bson.codecs.configuration.CodecRegistry -import java.lang.reflect.Type /** A Kotlin reflection based Codec Provider for data classes */ public class DataClassCodecProvider : CodecProvider { - override fun get(clazz: Class, registry: CodecRegistry): Codec? = - get(clazz, emptyList(), registry) + override fun get(clazz: Class, registry: CodecRegistry): Codec? = get(clazz, emptyList(), registry) override fun get(clazz: Class, typeArguments: List, registry: CodecRegistry): Codec? = DataClassCodec.create(clazz.kotlin, registry, typeArguments) diff --git a/bson-kotlin/src/test/kotlin/org/bson/codecs/kotlin/DataClassCodecProviderTest.kt b/bson-kotlin/src/test/kotlin/org/bson/codecs/kotlin/DataClassCodecProviderTest.kt index b36ada9622a..e0c7f9d1d1b 100644 --- a/bson-kotlin/src/test/kotlin/org/bson/codecs/kotlin/DataClassCodecProviderTest.kt +++ b/bson-kotlin/src/test/kotlin/org/bson/codecs/kotlin/DataClassCodecProviderTest.kt @@ -16,16 +16,16 @@ package org.bson.codecs.kotlin import com.mongodb.MongoClientSettings +import kotlin.test.assertEquals +import kotlin.test.assertNotNull +import kotlin.test.assertNull +import kotlin.test.assertTrue import org.bson.codecs.configuration.CodecConfigurationException import org.bson.codecs.kotlin.samples.DataClassParameterized import org.bson.codecs.kotlin.samples.DataClassWithSimpleValues import org.bson.conversions.Bson import org.junit.jupiter.api.Test import org.junit.jupiter.api.assertThrows -import kotlin.test.assertEquals -import kotlin.test.assertNotNull -import kotlin.test.assertNull -import kotlin.test.assertTrue class DataClassCodecProviderTest { diff --git a/driver-core/src/main/com/mongodb/internal/async/AsyncConsumer.java b/driver-core/src/main/com/mongodb/internal/async/AsyncConsumer.java new file mode 100644 index 00000000000..f3d54e172fa --- /dev/null +++ b/driver-core/src/main/com/mongodb/internal/async/AsyncConsumer.java @@ -0,0 +1,23 @@ +/* + * Copyright 2008-present MongoDB, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.mongodb.internal.async; + +/** + * @see AsyncRunnable + */ +public interface AsyncConsumer extends AsyncFunction { +} diff --git a/driver-core/src/main/com/mongodb/internal/async/AsyncFunction.java b/driver-core/src/main/com/mongodb/internal/async/AsyncFunction.java new file mode 100644 index 00000000000..0776edf815a --- /dev/null +++ b/driver-core/src/main/com/mongodb/internal/async/AsyncFunction.java @@ -0,0 +1,26 @@ +/* + * Copyright 2008-present MongoDB, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.mongodb.internal.async; + +import com.mongodb.lang.Nullable; + +/** + * @see AsyncRunnable + */ +public interface AsyncFunction { + void internal(@Nullable T value, SingleResultCallback callback); +} diff --git a/driver-core/src/main/com/mongodb/internal/async/AsyncRunnable.java b/driver-core/src/main/com/mongodb/internal/async/AsyncRunnable.java index 8e5a35f3dcf..9b5b1158312 100644 --- a/driver-core/src/main/com/mongodb/internal/async/AsyncRunnable.java +++ b/driver-core/src/main/com/mongodb/internal/async/AsyncRunnable.java @@ -20,39 +20,17 @@ import com.mongodb.internal.async.function.RetryingAsyncCallbackSupplier; import java.util.function.Predicate; +import java.util.function.Supplier; /** - * See AsyncRunnableTest for usage + * See tests for usage (AsyncFunctionsTest). */ -public interface AsyncRunnable { +public interface AsyncRunnable extends AsyncSupplier { static AsyncRunnable beginAsync() { return (c) -> c.onResult(null, null); } - void runInternal(SingleResultCallback callback); // NoResultCallback - - /** - * Must be invoked at end of async chain. Wraps the lambda in an error - * handler. - * @param callback the callback provided by the method the chain is used in - */ - default void finish(final SingleResultCallback callback) { - final boolean[] callbackInvoked = {false}; - try { - this.runInternal((v, e) -> { - callbackInvoked[0] = true; - callback.onResult(v, e); - }); - } catch (Throwable t) { - if (callbackInvoked[0]) { - throw t; - } else { - callback.onResult(null, t); - } - } - } - /** * Must be invoked at end of async chain * @param runnable the sync code to invoke (under non-exceptional flow) @@ -97,14 +75,14 @@ default void thenAlwaysRunAndFinish(final Runnable runnable, final SingleResultC } /** - * @param runnable The async runnable to run after this one - * @return the composition of this and the runnable + * @param runnable The async runnable to run after this runnable + * @return the composition of this runnable and the runnable, a runnable */ default AsyncRunnable thenRun(final AsyncRunnable runnable) { return (c) -> { - this.runInternal((r, e) -> { + this.internal((r, e) -> { if (e == null) { - runnable.runInternal(c); + runnable.internal(c); } else { c.onResult(null, e); } @@ -113,49 +91,49 @@ default AsyncRunnable thenRun(final AsyncRunnable runnable) { } /** - * @param supplier The supplier to supply using after this runnable. - * @return the composition of this runnable and the supplier - * @param The return type of the supplier + * @param condition the condition to check + * @param runnable The async runnable to run after this runnable, + * if and only if the condition is met + * @return the composition of this runnable and the runnable, a runnable */ - default AsyncSupplier thenSupply(final AsyncSupplier supplier) { - return (c) -> { - this.runInternal((r, e) -> { - if (e == null) { - supplier.supplyInternal(c); + default AsyncRunnable thenRunIf(final Supplier condition, final AsyncRunnable runnable) { + return (callback) -> { + this.internal((r, e) -> { + if (e != null) { + callback.onResult(null, e); + return; + } + boolean matched; + try { + matched = condition.get(); + } catch (Throwable t) { + callback.onResult(null, t); + return; + } + if (matched) { + runnable.internal(callback); } else { - c.onResult(null, e); + callback.onResult(null, null); } }); }; } /** - * @param errorCheck A check, comparable to a catch-if/otherwise-rethrow - * @param runnable The branch to execute if the error matches - * @return The composition of this, and the conditional branch + * @param supplier The supplier to supply using after this runnable + * @return the composition of this runnable and the supplier, a supplier + * @param The return type of the resulting supplier */ - default AsyncRunnable onErrorRunIf( - final Predicate errorCheck, - final AsyncRunnable runnable) { - return (callback) -> this.runInternal((r, e) -> { - if (e == null) { - callback.onResult(r, null); - return; - } - boolean errorMatched; - try { - errorMatched = errorCheck.test(e); - } catch (Throwable t) { - t.addSuppressed(e); - callback.onResult(null, t); - return; - } - if (errorMatched) { - runnable.runInternal(callback); - return; - } - callback.onResult(r, e); - }); + default AsyncSupplier thenSupply(final AsyncSupplier supplier) { + return (c) -> { + this.internal((r, e) -> { + if (e == null) { + supplier.internal(c); + } else { + c.onResult(null, e); + } + }); + }; } /** @@ -166,11 +144,11 @@ default AsyncRunnable onErrorRunIf( */ default AsyncRunnable thenRunRetryingWhile( final AsyncRunnable runnable, final Predicate shouldRetry) { - return this.thenRun(callback -> { + return thenRun(callback -> { new RetryingAsyncCallbackSupplier( new RetryState(), (rs, lastAttemptFailure) -> shouldRetry.test(lastAttemptFailure), - cb -> runnable.finish(cb) + cb -> runnable.finish(cb) // finish is required here, to handle exceptions ).get(callback); }); } diff --git a/driver-core/src/main/com/mongodb/internal/async/AsyncSupplier.java b/driver-core/src/main/com/mongodb/internal/async/AsyncSupplier.java index 1f2e2ad8152..dce706fe228 100644 --- a/driver-core/src/main/com/mongodb/internal/async/AsyncSupplier.java +++ b/driver-core/src/main/com/mongodb/internal/async/AsyncSupplier.java @@ -19,11 +19,11 @@ import java.util.function.Predicate; /** - * See AsyncRunnableTest for usage + * @see AsyncRunnable */ public interface AsyncSupplier { - void supplyInternal(SingleResultCallback callback); + void internal(SingleResultCallback callback); /** * Must be invoked at end of async chain @@ -32,7 +32,7 @@ public interface AsyncSupplier { default void finish(final SingleResultCallback callback) { final boolean[] callbackInvoked = {false}; try { - this.supplyInternal((v, e) -> { + this.internal((v, e) -> { callbackInvoked[0] = true; callback.onResult(v, e); }); @@ -46,16 +46,48 @@ default void finish(final SingleResultCallback callback) { } /** - * @see AsyncRunnable#onErrorRunIf(Predicate, AsyncRunnable). - * + * @param function The async function to run after this runnable + * @return the composition of this supplier and the function, a supplier + * @param The return type of the resulting supplier + */ + default AsyncSupplier thenApply(final AsyncFunction function) { + return (c) -> { + this.internal((v, e) -> { + if (e == null) { + function.internal(v, c); + } else { + c.onResult(null, e); + } + }); + }; + } + + + /** + * @param consumer The async consumer to run after this supplier + * @return the composition of this supplier and the consumer, a runnable + */ + default AsyncRunnable thenConsume(final AsyncConsumer consumer) { + return (c) -> { + this.internal((v, e) -> { + if (e == null) { + consumer.internal(v, c); + } else { + c.onResult(null, e); + } + }); + }; + } + + /** * @param errorCheck A check, comparable to a catch-if/otherwise-rethrow * @param supplier The branch to execute if the error matches * @return The composition of this, and the conditional branch */ - default AsyncSupplier onErrorSupplyIf( + default AsyncSupplier onErrorIf( final Predicate errorCheck, final AsyncSupplier supplier) { - return (callback) -> this.finish((r, e) -> { + return (callback) -> this.internal((r, e) -> { if (e == null) { callback.onResult(r, null); return; @@ -69,10 +101,11 @@ default AsyncSupplier onErrorSupplyIf( return; } if (errorMatched) { - supplier.finish(callback); - return; + supplier.internal(callback); + } else { + callback.onResult(null, e); } - callback.onResult(null, e); }); } + } diff --git a/driver-core/src/main/com/mongodb/internal/connection/InternalStreamConnection.java b/driver-core/src/main/com/mongodb/internal/connection/InternalStreamConnection.java index adca7b688c8..88f8d38a4c5 100644 --- a/driver-core/src/main/com/mongodb/internal/connection/InternalStreamConnection.java +++ b/driver-core/src/main/com/mongodb/internal/connection/InternalStreamConnection.java @@ -78,6 +78,7 @@ import static com.mongodb.assertions.Assertions.notNull; import static com.mongodb.internal.async.AsyncRunnable.beginAsync; import static com.mongodb.internal.async.ErrorHandlingResultCallback.errorHandlingCallback; +import static com.mongodb.internal.connection.Authenticator.shouldAuthenticate; import static com.mongodb.internal.connection.CommandHelper.HELLO; import static com.mongodb.internal.connection.CommandHelper.LEGACY_HELLO; import static com.mongodb.internal.connection.CommandHelper.LEGACY_HELLO_LOWER; @@ -383,7 +384,7 @@ public T sendAndReceive(final CommandMessage message, final Decoder decod try { return sendAndReceiveInternal.get(); } catch (MongoCommandException e) { - if (triggersReauthentication(e) && Authenticator.shouldAuthenticate(authenticator, this.description)) { + if (triggersReauthentication(e) && shouldAuthenticate(authenticator, this.description)) { authenticated.set(false); authenticator.reauthenticate(this); authenticated.set(true); @@ -400,8 +401,8 @@ public void sendAndReceiveAsync(final CommandMessage message, final Decoder< AsyncSupplier sendAndReceiveAsyncInternal = c -> sendAndReceiveAsyncInternal( message, decoder, sessionContext, requestContext, operationContext, c); - sendAndReceiveAsyncInternal.onErrorSupplyIf(e -> - triggersReauthentication(e) && Authenticator.shouldAuthenticate(authenticator, this.description), beginAsync() + sendAndReceiveAsyncInternal.onErrorIf(e -> + triggersReauthentication(e) && shouldAuthenticate(authenticator, this.description), beginAsync() .thenRun(c -> { authenticated.set(false); authenticator.reauthenticateAsync(this, c); diff --git a/driver-core/src/main/com/mongodb/internal/connection/OidcAuthenticator.java b/driver-core/src/main/com/mongodb/internal/connection/OidcAuthenticator.java index 78755f1b5e3..d21770a3c99 100644 --- a/driver-core/src/main/com/mongodb/internal/connection/OidcAuthenticator.java +++ b/driver-core/src/main/com/mongodb/internal/connection/OidcAuthenticator.java @@ -224,7 +224,7 @@ void authenticateAsync( if (accessToken != null) { beginAsync().thenRun(c -> { authenticateAsyncUsing(connection, connectionDescription, (challenge) -> prepareTokenAsJwt(accessToken), c); - }).onErrorRunIf(e -> triggersRetry(e), c -> { + }).onErrorIf(e -> triggersRetry(e), c -> { authLockAsync(connection, connectionDescription, c); }).finish(callback); } else { diff --git a/driver-core/src/test/unit/com/mongodb/internal/async/AsyncFunctionsTest.java b/driver-core/src/test/unit/com/mongodb/internal/async/AsyncFunctionsTest.java new file mode 100644 index 00000000000..15d9048c15c --- /dev/null +++ b/driver-core/src/test/unit/com/mongodb/internal/async/AsyncFunctionsTest.java @@ -0,0 +1,586 @@ +/* + * Copyright 2008-present MongoDB, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.mongodb.internal.async; + +import com.mongodb.client.TestListener; +import org.junit.jupiter.api.Test; + +import java.util.ArrayList; +import java.util.List; +import java.util.concurrent.atomic.AtomicReference; +import java.util.function.Consumer; +import java.util.function.Supplier; + +import static com.mongodb.internal.async.AsyncRunnable.beginAsync; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.fail; + +final class AsyncFunctionsTest { + private final TestListener listener = new TestListener(); + private final InvocationTracker invocationTracker = new InvocationTracker(); + + @Test + void testVariations1() { + /* + In our async code: + 1. a callback is provided as a method parameter + 2. at least one sync method must be converted to async + + To use this API: + 1. start an async chain using the static method + 2. use an appropriate chaining method (then...), which will provide "c" + 3. move all sync code into that method + 4. at the async method, pass in "c" and start a new chaining method + 5. finish by invoking the original "callback" at the end of the chain + + Async methods MUST be preceded by unaffected "plain" sync code (sync + code with no async counterpart), and this code MUST reside above the + affected method, as it appears in the sync code. + + Plain sync code MAY throw exceptions, and SHOULD NOT attempt to handle + them asynchronously. The exceptions will be caught and handled by the + chaining methods that contain this sync code. + + Each async lambda MUST invoke its async method with "c", and MUST return + immediately after invoking that method. It MUST NOT, for example, have + a catch or finally (including close on try-with-resources) after the + invocation of the sync method. + + Always use a braced lambda body with no linebreak before ".", as shown + below, to ensure that the async code can be compared to the sync code. + */ + + // the number of expected variations is often: 1 + N methods invoked + // 1 variation with no exceptions, and N per an exception in each method + assertBehavesSameVariations(2, + () -> { + // single sync method invocations... + sync(1); + }, + (callback) -> { + // ...become a single async invocation, wrapped in begin-thenRun/finish: + beginAsync().thenRun(c -> { + async(1, c); + }).finish(callback); + }); + } + + @Test + void testVariations2() { + // tests pairs + // converting: plain-sync, sync-plain, sync-sync + // (plain-plain does not need an async chain) + + assertBehavesSameVariations(3, + () -> { + // plain (unaffected) invocations... + plain(1); + sync(2); + }, + (callback) -> { + beginAsync().thenRun(c -> { + // ...are preserved above affected methods + plain(1); + async(2, c); + }).finish(callback); + }); + + assertBehavesSameVariations(3, + () -> { + // when a plain invocation follows an affected method... + sync(1); + plain(2); + }, + (callback) -> { + // ...it is moved to its own block + beginAsync().thenRun(c -> { + async(1, c); + }).thenRunAndFinish(() -> { + plain(2); + }, callback); + }); + + assertBehavesSameVariations(3, + () -> { + // when an affected method follows an affected method + sync(1); + sync(2); + }, + (callback) -> { + // ...it is moved to its own block + beginAsync().thenRun(c -> { + async(1, c); + }).thenRun(c -> { + async(2, c); + }).finish(callback); + }); + } + + @Test + void testVariations4() { + // tests the sync-sync pair with preceding and ensuing plain methods: + assertBehavesSameVariations(5, + () -> { + plain(11); + sync(1); + plain(22); + sync(2); + }, + (callback) -> { + beginAsync().thenRun(c -> { + plain(11); + async(1, c); + }).thenRun(c -> { + plain(22); + async(2, c); + }).finish(callback); + }); + + assertBehavesSameVariations(5, + () -> { + sync(1); + plain(11); + sync(2); + plain(22); + }, + (callback) -> { + beginAsync().thenRun(c -> { + async(1, c); + }).thenRun(c -> { + plain(11); + async(2, c); + }).thenRunAndFinish(() ->{ + plain(22); + }, callback); + }); + } + + @Test + void testSupply() { + assertBehavesSameVariations(4, + () -> { + sync(0); + plain(1); + return syncReturns(2); + }, + (callback) -> { + beginAsync().thenRun(c -> { + async(0, c); + }).thenSupply(c -> { + plain(1); + asyncReturns(2, c); + }).finish(callback); + }); + } + + @SuppressWarnings("ConstantConditions") + @Test + void testFullChain() { + // tests a chain: runnable, producer, function, function, consumer + + assertBehavesSameVariations(14, + () -> { + plain(90); + sync(0); + plain(91); + sync(1); + plain(92); + int v = syncReturns(2); + plain(93); + v = syncReturns(v + 1); + plain(94); + v = syncReturns(v + 10); + plain(95); + sync(v + 100); + plain(96); + }, + (callback) -> { + beginAsync().thenRun(c -> { + plain(90); + async(0, c); + }).thenRun(c -> { + plain(91); + async(1, c); + }).thenSupply(c -> { + plain(92); + asyncReturns(2, c); + }).thenApply((v, c) -> { + plain(93); + asyncReturns(v + 1, c); + }).thenApply((v, c) -> { + plain(94); + asyncReturns(v + 10, c); + }).thenConsume((v, c) -> { + plain(95); + async(v + 100, c); + }).thenRunAndFinish(() -> { + plain(96); + }, callback); + }); + } + + @Test + void testVariationsBranching() { + assertBehavesSameVariations(5, + () -> { + if (plainTest(1)) { + sync(2); + } else { + sync(3); + } + }, + (callback) -> { + beginAsync().thenRun(c -> { + if (plainTest(1)) { + async(2, c); + } else { + async(3, c); + } + }).finish(callback); + }); + + // 2 : fail on first sync, fail on test + // 3 : true test, sync2, sync3 + // 2 : false test, sync3 + // 7 total + assertBehavesSameVariations(7, + () -> { + sync(0); + if (plainTest(1)) { + sync(2); + } + sync(3); + }, + (callback) -> { + beginAsync().thenRun(c -> { + async(0, c); + }).thenRunIf(() -> plainTest(1), c -> { + async(2, c); + }).thenRun(c -> { + async(3, c); + }).finish(callback); + }); + + // an additional affected method within the "if" branch + assertBehavesSameVariations(8, + () -> { + sync(0); + if (plainTest(1)) { + sync(21); + sync(22); + } + sync(3); + }, + (callback) -> { + beginAsync().thenRun(c -> { + async(0, c); + }).thenRunIf(() -> plainTest(1), + beginAsync().thenRun(c -> { + async(21, c); + }).thenRun((c) -> { + async(22, c); + }) + ).thenRun(c -> { + async(3, c); + }).finish(callback); + }); + } + + @Test + void testErrorIf() { + assertBehavesSameVariations(5, + () -> { + try { + return syncReturns(1); + } catch (Exception e) { + if (e.getMessage().equals(plainTest(1) ? "unexpected" : "exception-1")) { + return syncReturns(2); + } else { + throw e; + } + } + }, + (callback) -> { + beginAsync().thenSupply(c -> { + asyncReturns(1, c); + }).onErrorIf(e -> e.getMessage().equals(plainTest(1) ? "unexpected" : "exception-1"), c -> { + asyncReturns(2, c); + }).finish(callback); + }); + } + + @Test + void testLoop() { + assertBehavesSameVariations(InvocationTracker.DEPTH_LIMIT * 2 + 1, + () -> { + while (true) { + try { + sync(plainTest(0) ? 1 : 2); + break; + } catch (RuntimeException e) { + if (e.getMessage().equals("exception-1")) { + continue; + } + throw e; + } + } + }, + (callback) -> { + beginAsync().thenRunRetryingWhile( + c -> sync(plainTest(0) ? 1 : 2), + e -> e.getMessage().equals("exception-1") + ).finish(callback); + }); + } + + @Test + void testFinally() { + // (in try: normal flow + exception + exception) * (in finally: normal + exception) = 6 + assertBehavesSameVariations(6, + () -> { + try { + plain(1); + sync(2); + } finally { + plain(3); + } + }, + (callback) -> { + beginAsync().thenRun(c -> { + plain(1); + async(2, c); + }).thenAlwaysRunAndFinish(() -> { + plain(3); + }, callback); + }); + } + + @Test + void testInvalid() { + assertThrows(IllegalStateException.class, () -> { + beginAsync().thenRun(c -> { + async(3, c); + throw new IllegalStateException("must not cause second callback invocation"); + }).finish((v, e) -> {}); + }); + assertThrows(IllegalStateException.class, () -> { + beginAsync().thenRun(c -> { + async(3, c); + }).finish((v, e) -> { + throw new IllegalStateException("must not cause second callback invocation"); + }); + }); + } + + // invoked methods: + + private void plain(final int i) { + int cur = invocationTracker.getNextOption(2); + if (cur == 0) { + listener.add("plain-exception-" + i); + throw new RuntimeException("affected method exception-" + i); + } else { + listener.add("plain-success-" + i); + } + } + + private boolean plainTest(final int i) { + int cur = invocationTracker.getNextOption(3); + if (cur == 0) { + listener.add("plain-exception-" + i); + throw new RuntimeException("affected method exception-" + i); + } else if (cur == 1) { + listener.add("plain-false-" + i); + return false; + } else { + listener.add("plain-true-" + i); + return true; + } + } + + private void sync(final int i) { + int cur = invocationTracker.getNextOption(2); + if (cur == 0) { + listener.add("affected-exception-" + i); + throw new RuntimeException("exception-" + i); + } else { + listener.add("affected-success-" + i); + } + } + + private Integer syncReturns(final int i) { + int cur = invocationTracker.getNextOption(2); + if (cur == 0) { + listener.add("affected-exception-" + i); + throw new RuntimeException("exception-" + i); + } else { + listener.add("affected-success-" + i); + return i; + } + } + + private void async(final int i, final SingleResultCallback callback) { + try { + sync(i); + callback.onResult(null, null); + } catch (Throwable t) { + callback.onResult(null, t); + } + } + + private void asyncReturns(final int i, final SingleResultCallback callback) { + try { + callback.onResult(syncReturns(i), null); + } catch (Throwable t) { + callback.onResult(null, t); + } + } + + // assert methods: + + private void assertBehavesSameVariations(final int expectedVariations, final Runnable sync, + final Consumer> async) { + assertBehavesSameVariations( + expectedVariations, + () -> { + sync.run(); + return null; + }, + (c) -> { + async.accept((v, e) -> c.onResult(v, e)); + }); + } + + private void assertBehavesSameVariations(final int expectedVariations, final Supplier sync, + final Consumer> async) { + invocationTracker.reset(); + do { + invocationTracker.startInitialStep(); + assertBehavesSame( + sync, + () -> invocationTracker.startMatchStep(), + async); + + } while (invocationTracker.countDown()); + assertEquals(expectedVariations, invocationTracker.getVariationCount()); + } + + private void assertBehavesSame(final Supplier sync, final Runnable between, final Consumer> async) { + T expectedValue = null; + Throwable expectedException = null; + try { + expectedValue = sync.get(); + } catch (Throwable e) { + expectedException = e; + } + List expectedEvents = listener.getEventStrings(); + + listener.clear(); + between.run(); + + AtomicReference actualValue = new AtomicReference<>(); + AtomicReference actualException = new AtomicReference<>(); + try { + async.accept((v, e) -> { + actualValue.set(v); + actualException.set(e); + }); + } catch (Throwable e) { + fail("async threw instead of using callback"); + } + + // The following code can be used to debug variations: + // System.out.println("==="); + // System.out.println(listener.getEventStrings()); + // System.out.println("==="); + + assertEquals(expectedEvents, listener.getEventStrings(), "steps did not match"); + assertEquals(expectedValue, actualValue.get()); + assertEquals(expectedException == null, actualException.get() == null); + if (expectedException != null) { + assertEquals(expectedException.getMessage(), actualException.get().getMessage()); + assertEquals(expectedException.getClass(), actualException.get().getClass()); + } + + listener.clear(); + } + + /** + * Tracks invocations: allows testing of all variations of a method calls + */ + private static class InvocationTracker { + public static final int DEPTH_LIMIT = 50; + private final List invocationResults = new ArrayList<>(); + private boolean isMatchStep = false; // vs initial step + private int item = 0; + private int variationCount = 0; + + public void reset() { + variationCount = 0; + } + + public void startInitialStep() { + variationCount++; + isMatchStep = false; + item = -1; + } + + public int getNextOption(final int myOptionsSize) { + item++; + if (item >= invocationResults.size()) { + if (isMatchStep) { + fail("result should have been pre-initialized: steps may not match"); + } + if (isWithinDepthLimit()) { + invocationResults.add(myOptionsSize - 1); + } else { + invocationResults.add(0); // choose "0" option, usually an exception + } + } + return invocationResults.get(item); + } + + public void startMatchStep() { + isMatchStep = true; + item = -1; + } + + private boolean countDown() { + while (!invocationResults.isEmpty()) { + int lastItemIndex = invocationResults.size() - 1; + int lastItem = invocationResults.get(lastItemIndex); + if (lastItem > 0) { + // count current digit down by 1, until 0 + invocationResults.set(lastItemIndex, lastItem - 1); + return true; + } else { + // current digit completed, remove (move left) + invocationResults.remove(lastItemIndex); + } + } + return false; + } + + public int getVariationCount() { + return variationCount; + } + + public boolean isWithinDepthLimit() { + return invocationResults.size() < DEPTH_LIMIT; + } + } +} diff --git a/driver-core/src/test/unit/com/mongodb/internal/async/AsyncRunnableTest.java b/driver-core/src/test/unit/com/mongodb/internal/async/AsyncRunnableTest.java deleted file mode 100644 index 8cefe4de6cb..00000000000 --- a/driver-core/src/test/unit/com/mongodb/internal/async/AsyncRunnableTest.java +++ /dev/null @@ -1,702 +0,0 @@ -/* - * Copyright 2008-present MongoDB, Inc. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package com.mongodb.internal.async; - -import org.junit.jupiter.api.Test; - -import java.util.concurrent.atomic.AtomicInteger; -import java.util.concurrent.atomic.AtomicReference; -import java.util.function.Consumer; -import java.util.function.Supplier; - -import static com.mongodb.internal.async.AsyncRunnable.beginAsync; -import static org.junit.jupiter.api.Assertions.assertEquals; -import static org.junit.jupiter.api.Assertions.assertNotNull; -import static org.junit.jupiter.api.Assertions.assertNull; -import static org.junit.jupiter.api.Assertions.fail; - -final class AsyncRunnableTest { - private final AtomicInteger i = new AtomicInteger(); - - @Test - void testRunnableRun() { - /* - In our async code: - 1. a callback is provided - 2. at least one sync method must be converted to async - - To do this: - 1. start an async chain using the static method - 2. chain using the appropriate method, which will provide "c" - 3. move all sync code into that method - 4. at the async method, pass in "c" and start a new chained method - 5. complete by invoking the original "callback" at the end of the chain - - Async methods may be preceded by "unaffected" sync code, and this code - will reside above the affected method, as it appears in the sync code. - Below, these "unaffected" methods have no sync/async suffix. - - The return of each chained async method MUST be immediately preceded - by an invocation of the relevant async method using "c". - - Always use a braced lambda body to ensure that the form matches the - corresponding sync code. - */ - assertBehavesSame( - () -> { - multiply(); - incrementSync(); - }, - (callback) -> { - beginAsync().thenRun(c -> { - multiply(); - incrementAsync(c); - }).finish(callback); - }); - } - - @Test - void testRunnableRunSyncException() { - // Preceding sync code might throw an exception, so it SHOULD be moved - // into the chain. In any case, any possible exception thrown by it - // MUST be handled by passing it into the callback. - assertBehavesSame( - () -> { - throwException("msg"); - incrementSync(); - }, - (callback) -> { - beginAsync().thenRun(c -> { - throwException("msg"); - incrementAsync(c); - }).finish(callback); - }); - - } - - @Test - void testRunnableRunMultiple() { - // Code split across multiple affected methods: - assertBehavesSame( - () -> { - multiply(); - incrementSync(); - multiply(); - incrementSync(); - }, - (callback) -> { - beginAsync().thenRun(c -> { - multiply(); - incrementAsync(c); - }).thenRun(c -> { - multiply(); - incrementAsync(c); - }).finish(callback); - }); - } - - @Test - void testRunnableRunMultipleExceptionSkipping() { - // An exception in sync code causes ensuing code to be skipped, and - // split async code behaves in the same way: - assertBehavesSame( - () -> { - throwException("m"); - incrementSync(); - throwException("m2"); - incrementSync(); - }, - (callback) -> { - beginAsync().thenRun(c -> { - throwException("m"); - incrementAsync(c); - }).thenRun(c -> { - throwException("m2"); - incrementAsync(c); - }).finish(callback); - }); - } - - @Test - void testRunnableRunMultipleExceptionInAffectedSkipping() { - // Likewise, an exception in the affected method causes a skip: - assertBehavesSame( - () -> { - multiply(); - throwExceptionSync("msg"); - multiply(); - incrementSync(); - }, - (callback) -> { - beginAsync().thenRun(c -> { - multiply(); - throwExceptionAsync("msg", c); - }).thenRun(c -> { - multiply(); - incrementAsync(c); - }).finish(callback); - }); - } - - @Test - void testRunnableCompleteRunnable() { - // Sometimes, sync code follows the affected method, and it MUST be - // moved into the final method: - assertBehavesSame( - () -> { - incrementSync(); - multiply(); - }, - (callback) -> { - beginAsync().thenRun(c -> { - incrementAsync(c); - }).thenRunAndFinish(() -> { - multiply(); - }, callback); - }); - } - - @Test - void testRunnableCompleteRunnableExceptional() { - // ...this makes it easier to correctly handle its exceptions: - assertBehavesSame( - () -> { - incrementSync(); - throwException("m"); - }, - (callback) -> { - beginAsync().thenRun(c -> { - incrementAsync(c); - }).thenRunAndFinish(() -> { - throwException("m"); - }, callback); - }); - } - - @Test - void testRunnableCompleteRunnableSkippedWhenExceptional() { - // ...and to ensure that it is not executed when it should be skipped: - assertBehavesSame( - () -> { - throwExceptionSync("msg"); - multiply(); - }, - (callback) -> { - beginAsync().thenRun(c -> { - throwExceptionAsync("msg", c); - }).thenRunAndFinish(() -> { - multiply(); - }, callback); - }); - } - - @Test - void testRunnableCompleteAlways() { - // normal flow - assertBehavesSame( - () -> { - try { - multiply(); - incrementSync(); - } finally { - multiply(); - } - }, - (callback) -> { - beginAsync().thenRun(c -> { - multiply(); - incrementAsync(c); - }).thenAlwaysRunAndFinish(() -> { - multiply(); - }, callback); - }); - - } - - @Test - void testRunnableCompleteAlwaysExceptionInAffected() { - // exception in sync/async - assertBehavesSame( - () -> { - try { - multiply(); - throwExceptionSync("msg"); - } finally { - multiply(); - } - }, - (callback) -> { - beginAsync().thenRun(c -> { - multiply(); - throwExceptionAsync("msg", c); - }).thenAlwaysRunAndFinish(() -> { - multiply(); - }, callback); - }); - } - - @Test - void testRunnableCompleteAlwaysExceptionInUnaffected() { - // exception in unaffected code - assertBehavesSame( - () -> { - try { - throwException("msg"); - incrementSync(); - } finally { - multiply(); - } - }, - (callback) -> { - beginAsync().thenRun(c -> { - throwException("msg"); - incrementAsync(c); - }).thenAlwaysRunAndFinish(() -> { - multiply(); - }, callback); - }); - } - - @Test - void testRunnableCompleteAlwaysExceptionInFinally() { - // exception in finally - assertBehavesSame( - () -> { - try { - multiply(); - incrementSync(); - } finally { - throwException("msg"); - } - }, - (callback) -> { - beginAsync().thenRun(c -> { - multiply(); - incrementAsync(c); - }).thenAlwaysRunAndFinish(() -> { - throwException("msg"); - }, callback); - }); - } - - @Test - void testRunnableCompleteAlwaysExceptionInFinallyExceptional() { - // exception in finally, exceptional flow - assertBehavesSame( - () -> { - try { - throwException("first"); - incrementSync(); - } finally { - throwException("msg"); - } - }, - (callback) -> { - beginAsync().thenRun(c -> { - throwException("first"); - incrementAsync(c); - }).thenAlwaysRunAndFinish(() -> { - throwException("msg"); - }, callback); - }); - } - - @Test - void testRunnableSupply() { - assertBehavesSame( - () -> { - multiply(); - return valueSync(1); - }, - (callback) -> { - beginAsync().thenSupply(c -> { - multiply(); - valueAsync(1, c); - }).finish(callback); - }); - } - - @Test - void testRunnableSupplyExceptional() { - assertBehavesSame( - () -> { - throwException("msg"); - return valueSync(1); - }, - (callback) -> { - beginAsync().thenSupply(c -> { - throwException("msg"); - valueAsync(1, c); - }).finish(callback); - }); - } - - @Test - void testRunnableSupplyExceptionalInAffected() { - assertBehavesSame( - () -> { - throwExceptionSync("msg"); - return valueSync(1); - }, - (callback) -> { - beginAsync().thenRun(c -> { - throwExceptionAsync("msg", c); - }).thenSupply(c -> { - valueAsync(1, c); - }).finish(callback); - }); - } - - @Test - void testSupplierOnErrorIf() { - // no exception - assertBehavesSame( - () -> { - try { - return valueSync(1); - } catch (Exception e) { - if (e.getMessage().equals("m1")) { - return valueSync(2); - } else { - throw e; - } - } - }, - (SingleResultCallback callback) -> { - beginAsync().thenSupply(c -> { - valueAsync(1, c); - }).onErrorSupplyIf(e -> e.getMessage().equals("m1"), c -> { - valueAsync(2, c); - }).finish(callback); - }); - } - - @Test - void testSupplierOnErrorIfWithValueBranch() { - // exception, with value branch - assertBehavesSame( - () -> { - try { - return throwExceptionSync("m1"); - } catch (Exception e) { - if (e.getMessage().equals("m1")) { - return valueSync(2); - } else { - throw e; - } - } - }, - (callback) -> { - beginAsync().thenSupply(c -> { - throwExceptionAsync("m1", c); - }).onErrorSupplyIf(e -> e.getMessage().equals("m1"), c -> { - valueAsync(2, c); - }).finish(callback); - }); - - } - - @Test - void testSupplierOnErrorIfWithExceptionBranch() { - // exception, with exception branch - assertBehavesSame( - () -> { - try { - return throwExceptionSync("m1"); - } catch (Exception e) { - if (e.getMessage().equals("m1")) { - return this.throwExceptionSync("m2"); - } else { - throw e; - } - } - }, - (callback) -> { - beginAsync().thenSupply(c -> { - throwExceptionAsync("m1", c); - }).onErrorSupplyIf(e -> e.getMessage().equals("m1"), c -> { - throwExceptionAsync("m2", c); - }).finish(callback); - }); - } - - @Test - void testRunnableOnErrorIfNoException() { - // no exception - assertBehavesSame( - () -> { - try { - incrementSync(); - } catch (Exception e) { - if (e.getMessage().equals("m1")) { - multiply(); - incrementSync(); - } else { - throw e; - } - } - }, - (callback) -> { - beginAsync().thenRun(c -> { - incrementSync(); - }).onErrorRunIf(e -> e.getMessage().equals("m1"), c -> { - multiply(); - incrementAsync(c); - }).finish(callback); - }); - - } - - @Test - void testRunnableOnErrorIfThrowsMatching() { - // throws matching exception - assertBehavesSame( - () -> { - try { - throwExceptionSync("m1"); - } catch (Exception e) { - if (e.getMessage().equals("m1")) { - multiply(); - incrementSync(); - } else { - throw e; - } - } - }, - (callback) -> { - beginAsync().thenRun(c -> { - throwExceptionAsync("m1", c); - }).onErrorRunIf(e -> e.getMessage().equals("m1"), c -> { - multiply(); - incrementAsync(c); - }).finish(callback); - }); - - } - - @Test - void testRunnableOnErrorIfThrowsNonMatching() { - // throws non-matching exception - assertBehavesSame( - () -> { - try { - throwExceptionSync("not-m1"); - } catch (Exception e) { - if (e.getMessage().equals("m1")) { - multiply(); - incrementSync(); - } else { - throw e; - } - } - }, - (callback) -> { - beginAsync().thenRun(c -> { - throwExceptionAsync("not-m1", c); - }).onErrorRunIf(e -> e.getMessage().equals("m1"), c -> { - multiply(); - incrementAsync(c); - }).finish(callback); - }); - } - - @Test - void testRunnableOnErrorIfCheckFails() { - // throws but check fails with exception - assertBehavesSame( - () -> { - try { - throwExceptionSync("m1"); - } catch (Exception e) { - if (throwException("check fails")) { - multiply(); - incrementSync(); - } else { - throw e; - } - } - }, - (callback) -> { - beginAsync().thenRun(c -> { - throwExceptionAsync("m1", c); - }).onErrorRunIf(e -> throwException("check fails"), c -> { - multiply(); - incrementAsync(c); - }).finish(callback); - }); - } - - @Test - void testRunnableOnErrorIfSyncBranchfails() { - // throws but sync code in branch fails - assertBehavesSame( - () -> { - try { - throwExceptionSync("m1"); - } catch (Exception e) { - if (e.getMessage().equals("m1")) { - throwException("branch"); - incrementSync(); - } else { - throw e; - } - } - }, - (callback) -> { - beginAsync().thenRun(c -> { - throwExceptionAsync("m1", c); - }).onErrorRunIf(e -> e.getMessage().equals("m1"), c -> { - throwException("branch"); - incrementAsync(c); - }).finish(callback); - }); - } - - @Test - void testRunnableOnErrorIfSyncBranchFailsWithMatching() { - // throws but sync code in branch fails with matching exception - assertBehavesSame( - () -> { - try { - throwExceptionSync("m1"); - } catch (Exception e) { - if (e.getMessage().equals("m1")) { - multiply(); - throwException("m1"); - incrementSync(); - } else { - throw e; - } - } - }, - (callback) -> { - beginAsync().thenRun(c -> { - throwExceptionAsync("m1", c); - }).onErrorRunIf(e -> e.getMessage().equals("m1"), c -> { - multiply(); - throwException("m1"); - incrementAsync(c); - }).finish(callback); - }); - } - - @Test - void testRunnableOnErrorIfThrowsAndBranchedAffectedMethodThrows() { - // throws, and branch sync/async method throws - assertBehavesSame( - () -> { - try { - throwExceptionSync("m1"); - } catch (Exception e) { - if (e.getMessage().equals("m1")) { - multiply(); - throwExceptionSync("m1"); - } else { - throw e; - } - } - }, - (callback) -> { - beginAsync().thenRun(c -> { - throwExceptionAsync("m1", c); - }).onErrorRunIf(e -> e.getMessage().equals("m1"), c -> { - multiply(); - throwExceptionAsync("m1", c); - }).finish(callback); - }); - } - - // unaffected methods: - - private T throwException(final String message) { - throw new RuntimeException(message); - } - - private void multiply() { - i.set(i.get() * 10); - } - - // affected sync-async pairs: - - private void incrementSync() { - i.addAndGet(1); - } - - private void incrementAsync(final SingleResultCallback callback) { - i.addAndGet(1); - callback.onResult(null, null); - } - - private T throwExceptionSync(final String msg) { - throw new RuntimeException(msg); - } - - private void throwExceptionAsync(final String msg, final SingleResultCallback callback) { - try { - throw new RuntimeException(msg); - } catch (Exception e) { - callback.onResult(null, e); - } - } - - private Integer valueSync(final int i) { - return i; - } - - private void valueAsync(final int i, final SingleResultCallback callback) { - callback.onResult(i, null); - } - - private void assertBehavesSame(final Runnable sync, final Consumer> async) { - assertBehavesSame( - () -> { - sync.run(); - return null; - }, - (c) -> { - async.accept((v, e) -> c.onResult(v, e)); - }); - } - - private void assertBehavesSame(final Supplier sync, final Consumer> async) { - AtomicReference actualValue = new AtomicReference<>(); - AtomicReference actualException = new AtomicReference<>(); - try { - i.set(1); - SingleResultCallback callback = (v, e) -> { - actualValue.set(v); - actualException.set(e); - }; - async.accept(callback); - } catch (Throwable e) { - fail("async threw instead of using callback"); - } - Integer expectedI = i.get(); - - try { - i.set(1); - T expectedValue = sync.get(); - assertEquals(expectedValue, actualValue.get()); - assertNull(actualException.get()); - } catch (Throwable e) { - assertNull(actualValue.get()); - assertNotNull(actualException.get(), "async failed to throw expected: " + e); - assertEquals(e.getClass(), actualException.get().getClass()); - assertEquals(e.getMessage(), actualException.get().getMessage()); - } - assertEquals(expectedI, i.get()); - } -} From e7d4ffdaa7c258bf2a9f02f6ee0cfc0ff32e13a6 Mon Sep 17 00:00:00 2001 From: Maxim Katcharov Date: Fri, 16 Jun 2023 12:09:51 -0600 Subject: [PATCH 09/12] Fixes, naming --- .../com/mongodb/internal/async/AsyncConsumer.java | 3 +++ .../com/mongodb/internal/async/AsyncFunction.java | 3 +++ .../com/mongodb/internal/async/AsyncRunnable.java | 3 +++ .../com/mongodb/internal/async/AsyncSupplier.java | 3 +++ .../internal/connection/OidcAuthenticator.java | 12 ++++++------ 5 files changed, 18 insertions(+), 6 deletions(-) diff --git a/driver-core/src/main/com/mongodb/internal/async/AsyncConsumer.java b/driver-core/src/main/com/mongodb/internal/async/AsyncConsumer.java index f3d54e172fa..7bdc08abee9 100644 --- a/driver-core/src/main/com/mongodb/internal/async/AsyncConsumer.java +++ b/driver-core/src/main/com/mongodb/internal/async/AsyncConsumer.java @@ -17,7 +17,10 @@ package com.mongodb.internal.async; /** + * This class is not part of the public API and may be removed or changed at any time> + * * @see AsyncRunnable */ +@FunctionalInterface public interface AsyncConsumer extends AsyncFunction { } diff --git a/driver-core/src/main/com/mongodb/internal/async/AsyncFunction.java b/driver-core/src/main/com/mongodb/internal/async/AsyncFunction.java index 0776edf815a..0f3f1c097d3 100644 --- a/driver-core/src/main/com/mongodb/internal/async/AsyncFunction.java +++ b/driver-core/src/main/com/mongodb/internal/async/AsyncFunction.java @@ -19,8 +19,11 @@ import com.mongodb.lang.Nullable; /** + * This class is not part of the public API and may be removed or changed at any time + * * @see AsyncRunnable */ +@FunctionalInterface public interface AsyncFunction { void internal(@Nullable T value, SingleResultCallback callback); } diff --git a/driver-core/src/main/com/mongodb/internal/async/AsyncRunnable.java b/driver-core/src/main/com/mongodb/internal/async/AsyncRunnable.java index 9b5b1158312..ee9bc7ca4ed 100644 --- a/driver-core/src/main/com/mongodb/internal/async/AsyncRunnable.java +++ b/driver-core/src/main/com/mongodb/internal/async/AsyncRunnable.java @@ -24,7 +24,10 @@ /** * See tests for usage (AsyncFunctionsTest). + *

+ * This class is not part of the public API and may be removed or changed at any time */ +@FunctionalInterface public interface AsyncRunnable extends AsyncSupplier { static AsyncRunnable beginAsync() { diff --git a/driver-core/src/main/com/mongodb/internal/async/AsyncSupplier.java b/driver-core/src/main/com/mongodb/internal/async/AsyncSupplier.java index dce706fe228..62a4a13971c 100644 --- a/driver-core/src/main/com/mongodb/internal/async/AsyncSupplier.java +++ b/driver-core/src/main/com/mongodb/internal/async/AsyncSupplier.java @@ -19,8 +19,11 @@ import java.util.function.Predicate; /** + * This class is not part of the public API and may be removed or changed at any time + * * @see AsyncRunnable */ +@FunctionalInterface public interface AsyncSupplier { void internal(SingleResultCallback callback); diff --git a/driver-core/src/main/com/mongodb/internal/connection/OidcAuthenticator.java b/driver-core/src/main/com/mongodb/internal/connection/OidcAuthenticator.java index d21770a3c99..fa3c0debf7f 100644 --- a/driver-core/src/main/com/mongodb/internal/connection/OidcAuthenticator.java +++ b/driver-core/src/main/com/mongodb/internal/connection/OidcAuthenticator.java @@ -201,7 +201,7 @@ public void authenticate(final InternalConnection connection, final ConnectionDe String accessToken = getValidCachedAccessToken(); if (accessToken != null) { try { - authenticateUsing(connection, connectionDescription, (challenge) -> prepareTokenAsJwt(accessToken)); + authenticateUsingFunction(connection, connectionDescription, (challenge) -> prepareTokenAsJwt(accessToken)); } catch (MongoSecurityException e) { if (triggersRetry(e)) { authLock(connection, connectionDescription); @@ -223,7 +223,7 @@ void authenticateAsync( String accessToken = getValidCachedAccessToken(); if (accessToken != null) { beginAsync().thenRun(c -> { - authenticateAsyncUsing(connection, connectionDescription, (challenge) -> prepareTokenAsJwt(accessToken), c); + authenticateUsingFunctionAsync(connection, connectionDescription, (challenge) -> prepareTokenAsJwt(accessToken), c); }).onErrorIf(e -> triggersRetry(e), c -> { authLockAsync(connection, connectionDescription, c); }).finish(callback); @@ -244,14 +244,14 @@ private static boolean triggersRetry(@Nullable final Throwable t) { return false; } - private void authenticateAsyncUsing(final InternalConnection connection, + private void authenticateUsingFunctionAsync(final InternalConnection connection, final ConnectionDescription connectionDescription, final Function evaluateChallengeFunction, final SingleResultCallback callback) { this.evaluateChallengeFunction = evaluateChallengeFunction; super.authenticateAsync(connection, connectionDescription, callback); } - private void authenticateUsing( + private void authenticateUsingFunction( final InternalConnection connection, final ConnectionDescription connectionDescription, final Function evaluateChallengeFunction) { @@ -264,7 +264,7 @@ private void authLock(final InternalConnection connection, final ConnectionDescr Locks.withLock(getMongoCredentialWithCache().getOidcLock(), () -> { while (true) { try { - authenticateUsing(connection, description, (challenge) -> evaluate(challenge)); + authenticateUsingFunction(connection, description, (challenge) -> evaluate(challenge)); break; } catch (MongoSecurityException e) { if (triggersRetry(e) && shouldRetryHandler()) { @@ -281,7 +281,7 @@ private void authLockAsync(final InternalConnection connection, final Connection fallbackState = FallbackState.INITIAL; Locks.withLockAsync(getMongoCredentialWithCache().getOidcLock(), beginAsync().thenRunRetryingWhile( - c -> authenticateAsyncUsing(connection, description, (challenge) -> evaluate(challenge), c), + c -> authenticateUsingFunctionAsync(connection, description, (challenge) -> evaluate(challenge), c), e -> triggersRetry(e) && shouldRetryHandler() ), callback); } From 4825ebe38833490101f566e6f59f4a5e9a3a015a Mon Sep 17 00:00:00 2001 From: Maxim Katcharov Date: Mon, 19 Jun 2023 12:20:22 -0600 Subject: [PATCH 10/12] Fix inheritance, refactor, more tests --- .../mongodb/internal/async/AsyncConsumer.java | 6 +- .../mongodb/internal/async/AsyncFunction.java | 10 ++- .../mongodb/internal/async/AsyncRunnable.java | 14 ++--- .../mongodb/internal/async/AsyncSupplier.java | 48 +++++++++++---- .../async/function/AsyncCallbackRunnable.java | 13 ---- .../connection/InternalStreamConnection.java | 50 +++++++++------ .../connection/OidcAuthenticator.java | 53 +++++++++------- .../internal/async/AsyncFunctionsTest.java | 61 +++++++++++++++++-- 8 files changed, 171 insertions(+), 84 deletions(-) diff --git a/driver-core/src/main/com/mongodb/internal/async/AsyncConsumer.java b/driver-core/src/main/com/mongodb/internal/async/AsyncConsumer.java index 7bdc08abee9..b385670ae88 100644 --- a/driver-core/src/main/com/mongodb/internal/async/AsyncConsumer.java +++ b/driver-core/src/main/com/mongodb/internal/async/AsyncConsumer.java @@ -17,9 +17,9 @@ package com.mongodb.internal.async; /** - * This class is not part of the public API and may be removed or changed at any time> - * - * @see AsyncRunnable + * See tests for usage (AsyncFunctionsTest). + *

+ * This class is not part of the public API and may be removed or changed at any time */ @FunctionalInterface public interface AsyncConsumer extends AsyncFunction { diff --git a/driver-core/src/main/com/mongodb/internal/async/AsyncFunction.java b/driver-core/src/main/com/mongodb/internal/async/AsyncFunction.java index 0f3f1c097d3..8caf176dce6 100644 --- a/driver-core/src/main/com/mongodb/internal/async/AsyncFunction.java +++ b/driver-core/src/main/com/mongodb/internal/async/AsyncFunction.java @@ -19,11 +19,15 @@ import com.mongodb.lang.Nullable; /** + * See tests for usage (AsyncFunctionsTest). + *

* This class is not part of the public API and may be removed or changed at any time - * - * @see AsyncRunnable */ @FunctionalInterface public interface AsyncFunction { - void internal(@Nullable T value, SingleResultCallback callback); + /** + * This should not be called externally, but should be implemented as a + * lambda. To "finish" an async chain, use one of the "finish" methods. + */ + void unsafeFinish(@Nullable T value, SingleResultCallback callback); } diff --git a/driver-core/src/main/com/mongodb/internal/async/AsyncRunnable.java b/driver-core/src/main/com/mongodb/internal/async/AsyncRunnable.java index ee9bc7ca4ed..8d4ee54d7f5 100644 --- a/driver-core/src/main/com/mongodb/internal/async/AsyncRunnable.java +++ b/driver-core/src/main/com/mongodb/internal/async/AsyncRunnable.java @@ -28,7 +28,7 @@ * This class is not part of the public API and may be removed or changed at any time */ @FunctionalInterface -public interface AsyncRunnable extends AsyncSupplier { +public interface AsyncRunnable extends AsyncSupplier, AsyncConsumer { static AsyncRunnable beginAsync() { return (c) -> c.onResult(null, null); @@ -83,9 +83,9 @@ default void thenAlwaysRunAndFinish(final Runnable runnable, final SingleResultC */ default AsyncRunnable thenRun(final AsyncRunnable runnable) { return (c) -> { - this.internal((r, e) -> { + this.unsafeFinish((r, e) -> { if (e == null) { - runnable.internal(c); + runnable.unsafeFinish(c); } else { c.onResult(null, e); } @@ -101,7 +101,7 @@ default AsyncRunnable thenRun(final AsyncRunnable runnable) { */ default AsyncRunnable thenRunIf(final Supplier condition, final AsyncRunnable runnable) { return (callback) -> { - this.internal((r, e) -> { + this.unsafeFinish((r, e) -> { if (e != null) { callback.onResult(null, e); return; @@ -114,7 +114,7 @@ default AsyncRunnable thenRunIf(final Supplier condition, final AsyncRu return; } if (matched) { - runnable.internal(callback); + runnable.unsafeFinish(callback); } else { callback.onResult(null, null); } @@ -129,9 +129,9 @@ default AsyncRunnable thenRunIf(final Supplier condition, final AsyncRu */ default AsyncSupplier thenSupply(final AsyncSupplier supplier) { return (c) -> { - this.internal((r, e) -> { + this.unsafeFinish((r, e) -> { if (e == null) { - supplier.internal(c); + supplier.unsafeFinish(c); } else { c.onResult(null, e); } diff --git a/driver-core/src/main/com/mongodb/internal/async/AsyncSupplier.java b/driver-core/src/main/com/mongodb/internal/async/AsyncSupplier.java index 62a4a13971c..78575e3fd1a 100644 --- a/driver-core/src/main/com/mongodb/internal/async/AsyncSupplier.java +++ b/driver-core/src/main/com/mongodb/internal/async/AsyncSupplier.java @@ -16,26 +16,50 @@ package com.mongodb.internal.async; +import com.mongodb.lang.Nullable; + import java.util.function.Predicate; + /** + * See tests for usage (AsyncFunctionsTest). + *

* This class is not part of the public API and may be removed or changed at any time - * - * @see AsyncRunnable */ @FunctionalInterface -public interface AsyncSupplier { +public interface AsyncSupplier extends AsyncFunction { + /** + * This should not be called externally to this API. It should be + * implemented as a lambda. To "finish" an async chain, use one of + * the "finish" methods. + * + * @see #finish(SingleResultCallback) + */ + void unsafeFinish(SingleResultCallback callback); - void internal(SingleResultCallback callback); + /** + * This method must only be used when this AsyncSupplier corresponds + * to a {@link java.util.function.Supplier} (and is therefore being + * used within an async chain method lambda). + * @param callback the callback + */ + default void getAsync(final SingleResultCallback callback) { + unsafeFinish(callback); + } + + @Override + default void unsafeFinish(@Nullable final Void value, final SingleResultCallback callback) { + unsafeFinish(callback); + } /** - * Must be invoked at end of async chain + * Must be invoked at end of async chain. * @param callback the callback provided by the method the chain is used in */ default void finish(final SingleResultCallback callback) { final boolean[] callbackInvoked = {false}; try { - this.internal((v, e) -> { + this.unsafeFinish((v, e) -> { callbackInvoked[0] = true; callback.onResult(v, e); }); @@ -55,9 +79,9 @@ default void finish(final SingleResultCallback callback) { */ default AsyncSupplier thenApply(final AsyncFunction function) { return (c) -> { - this.internal((v, e) -> { + this.unsafeFinish((v, e) -> { if (e == null) { - function.internal(v, c); + function.unsafeFinish(v, c); } else { c.onResult(null, e); } @@ -72,9 +96,9 @@ default AsyncSupplier thenApply(final AsyncFunction function) { */ default AsyncRunnable thenConsume(final AsyncConsumer consumer) { return (c) -> { - this.internal((v, e) -> { + this.unsafeFinish((v, e) -> { if (e == null) { - consumer.internal(v, c); + consumer.unsafeFinish(v, c); } else { c.onResult(null, e); } @@ -90,7 +114,7 @@ default AsyncRunnable thenConsume(final AsyncConsumer consumer) { default AsyncSupplier onErrorIf( final Predicate errorCheck, final AsyncSupplier supplier) { - return (callback) -> this.internal((r, e) -> { + return (callback) -> this.unsafeFinish((r, e) -> { if (e == null) { callback.onResult(r, null); return; @@ -104,7 +128,7 @@ default AsyncSupplier onErrorIf( return; } if (errorMatched) { - supplier.internal(callback); + supplier.unsafeFinish(callback); } else { callback.onResult(null, e); } diff --git a/driver-core/src/main/com/mongodb/internal/async/function/AsyncCallbackRunnable.java b/driver-core/src/main/com/mongodb/internal/async/function/AsyncCallbackRunnable.java index 7304a9ef9b5..02fdbdf9699 100644 --- a/driver-core/src/main/com/mongodb/internal/async/function/AsyncCallbackRunnable.java +++ b/driver-core/src/main/com/mongodb/internal/async/function/AsyncCallbackRunnable.java @@ -32,17 +32,4 @@ public interface AsyncCallbackRunnable { */ void run(SingleResultCallback callback); - /** - * Converts this {@link AsyncCallbackSupplier} to {@link AsyncCallbackSupplier}{@code }. - */ - default AsyncCallbackSupplier asSupplier() { - return this::run; - } - - /** - * @see AsyncCallbackSupplier#whenComplete(Runnable) - */ - default AsyncCallbackRunnable whenComplete(final Runnable after) { - return callback -> asSupplier().whenComplete(after).get(callback); - } } diff --git a/driver-core/src/main/com/mongodb/internal/connection/InternalStreamConnection.java b/driver-core/src/main/com/mongodb/internal/connection/InternalStreamConnection.java index 88f8d38a4c5..e9fa002bb38 100644 --- a/driver-core/src/main/com/mongodb/internal/connection/InternalStreamConnection.java +++ b/driver-core/src/main/com/mongodb/internal/connection/InternalStreamConnection.java @@ -384,36 +384,48 @@ public T sendAndReceive(final CommandMessage message, final Decoder decod try { return sendAndReceiveInternal.get(); } catch (MongoCommandException e) { - if (triggersReauthentication(e) && shouldAuthenticate(authenticator, this.description)) { - authenticated.set(false); - authenticator.reauthenticate(this); - authenticated.set(true); - return sendAndReceiveInternal.get(); + if (reauthenticationIsTriggered(e)) { + return reauthenticateAndRetry(sendAndReceiveInternal); } throw e; } } - @Override public void sendAndReceiveAsync(final CommandMessage message, final Decoder decoder, final SessionContext sessionContext, final RequestContext requestContext, final OperationContext operationContext, final SingleResultCallback callback) { AsyncSupplier sendAndReceiveAsyncInternal = c -> sendAndReceiveAsyncInternal( message, decoder, sessionContext, requestContext, operationContext, c); - sendAndReceiveAsyncInternal.onErrorIf(e -> - triggersReauthentication(e) && shouldAuthenticate(authenticator, this.description), beginAsync() - .thenRun(c -> { - authenticated.set(false); - authenticator.reauthenticateAsync(this, c); - }).thenSupply((c) -> { - authenticated.set(true); - sendAndReceiveAsyncInternal.finish(c); - })) - .finish(callback); - } - - public static boolean triggersReauthentication(@Nullable final Throwable t) { + beginAsync().thenSupply(c -> { + sendAndReceiveAsyncInternal.getAsync(c); + }).onErrorIf(e -> reauthenticationIsTriggered(e), c -> { + reauthenticateAndRetryAsync(sendAndReceiveAsyncInternal, c); + }).finish(callback); + } + + private T reauthenticateAndRetry(final Supplier operation) { + authenticated.set(false); + assertNotNull(authenticator).reauthenticate(this); + authenticated.set(true); + return operation.get(); + } + + private void reauthenticateAndRetryAsync(final AsyncSupplier operation, + final SingleResultCallback callback) { + beginAsync().thenRun(c -> { + authenticated.set(false); + assertNotNull(authenticator).reauthenticateAsync(this, c); + }).thenSupply((c) -> { + authenticated.set(true); + operation.getAsync(c); + }).finish(callback); + } + + public boolean reauthenticationIsTriggered(@Nullable final Throwable t) { + if (!shouldAuthenticate(authenticator, this.description)) { + return false; + } if (t instanceof MongoCommandException) { MongoCommandException e = (MongoCommandException) t; return e.getErrorCode() == 391; diff --git a/driver-core/src/main/com/mongodb/internal/connection/OidcAuthenticator.java b/driver-core/src/main/com/mongodb/internal/connection/OidcAuthenticator.java index fa3c0debf7f..fe86e138e2f 100644 --- a/driver-core/src/main/com/mongodb/internal/connection/OidcAuthenticator.java +++ b/driver-core/src/main/com/mongodb/internal/connection/OidcAuthenticator.java @@ -200,38 +200,49 @@ public void authenticate(final InternalConnection connection, final ConnectionDe assertFalse(connection.opened()); String accessToken = getValidCachedAccessToken(); if (accessToken != null) { - try { - authenticateUsingFunction(connection, connectionDescription, (challenge) -> prepareTokenAsJwt(accessToken)); - } catch (MongoSecurityException e) { - if (triggersRetry(e)) { - authLock(connection, connectionDescription); - } else { - throw e; - } - } + authenticateOptimistically(connection, connectionDescription, accessToken); } else { authLock(connection, connectionDescription); } } @Override - void authenticateAsync( - final InternalConnection connection, - final ConnectionDescription connectionDescription, + void authenticateAsync(final InternalConnection connection, final ConnectionDescription connectionDescription, final SingleResultCallback callback) { - assertFalse(connection.opened()); - String accessToken = getValidCachedAccessToken(); - if (accessToken != null) { - beginAsync().thenRun(c -> { - authenticateUsingFunctionAsync(connection, connectionDescription, (challenge) -> prepareTokenAsJwt(accessToken), c); - }).onErrorIf(e -> triggersRetry(e), c -> { + beginAsync().thenRun(c -> { + assertFalse(connection.opened()); + String accessToken = getValidCachedAccessToken(); + if (accessToken != null) { + authenticateOptimisticallyAsync(connection, connectionDescription, accessToken, c); + } else { authLockAsync(connection, connectionDescription, c); - }).finish(callback); - } else { - authLockAsync(connection, connectionDescription, callback); + } + }).finish(callback); + } + + private void authenticateOptimistically(final InternalConnection connection, + final ConnectionDescription connectionDescription, final String accessToken) { + try { + authenticateUsingFunction(connection, connectionDescription, (challenge) -> prepareTokenAsJwt(accessToken)); + } catch (MongoSecurityException e) { + if (triggersRetry(e)) { + authLock(connection, connectionDescription); + } else { + throw e; + } } } + private void authenticateOptimisticallyAsync(final InternalConnection connection, + final ConnectionDescription connectionDescription, final String accessToken, + final SingleResultCallback callback) { + beginAsync().thenRun(c -> { + authenticateUsingFunctionAsync(connection, connectionDescription, (challenge) -> prepareTokenAsJwt(accessToken), c); + }).onErrorIf(e -> triggersRetry(e), c -> { + authLockAsync(connection, connectionDescription, c); + }).finish(callback); + } + private static boolean triggersRetry(@Nullable final Throwable t) { if (t instanceof MongoSecurityException) { MongoSecurityException e = (MongoSecurityException) t; diff --git a/driver-core/src/test/unit/com/mongodb/internal/async/AsyncFunctionsTest.java b/driver-core/src/test/unit/com/mongodb/internal/async/AsyncFunctionsTest.java index 15d9048c15c..009fb6c7dde 100644 --- a/driver-core/src/test/unit/com/mongodb/internal/async/AsyncFunctionsTest.java +++ b/driver-core/src/test/unit/com/mongodb/internal/async/AsyncFunctionsTest.java @@ -41,15 +41,18 @@ void testVariations1() { 2. at least one sync method must be converted to async To use this API: - 1. start an async chain using the static method + 1. start an async chain using the "beginAsync" static method 2. use an appropriate chaining method (then...), which will provide "c" - 3. move all sync code into that method + 3. copy all sync code to that method 4. at the async method, pass in "c" and start a new chaining method - 5. finish by invoking the original "callback" at the end of the chain + 5. provide the original "callback" at the end of the chain via "finish" Async methods MUST be preceded by unaffected "plain" sync code (sync code with no async counterpart), and this code MUST reside above the - affected method, as it appears in the sync code. + affected method, as it appears in the sync code. Plain code after + the sync method should be supplied via one of the "finally" variants. + Safe "shared" plain code (variable and lambda declarations) which cannot + throw, may remain outside the chained invocations, for convenience. Plain sync code MAY throw exceptions, and SHOULD NOT attempt to handle them asynchronously. The exceptions will be caught and handled by the @@ -60,8 +63,9 @@ void testVariations1() { a catch or finally (including close on try-with-resources) after the invocation of the sync method. - Always use a braced lambda body with no linebreak before ".", as shown - below, to ensure that the async code can be compared to the sync code. + A braced lambda body (with no linebreak before "."), as shown below, + should be used, as this will be consistent with other usages, and allows + the async code to be more easily compared to the sync code. */ // the number of expected variations is often: 1 + N methods invoked @@ -369,6 +373,51 @@ void testFinally() { }); } + @Test + void testUsedAsLambda() { + assertBehavesSameVariations(4, + () -> { + Supplier s = () -> syncReturns(9); + sync(0); + plain(1); + return s.get(); + }, + (callback) -> { + AsyncSupplier s = (c) -> asyncReturns(9, c); + beginAsync().thenRun(c -> { + async(0, c); + }).thenSupply((c) -> { + plain(1); + s.getAsync(c); + }).finish(callback); + }); + } + + @Test + void testVariables() { + assertBehavesSameVariations(3, + () -> { + int something; + something = 90; + sync(something); + something = something + 10; + sync(something); + }, + (callback) -> { + // Certain variables may need to be shared; these can be + // declared (but not initialized) outside the async chain. + // Any container works (atomic allowed but not needed) + final int[] something = new int[1]; + beginAsync().thenRun(c -> { + something[0] = 90; + async(something[0], c); + }).thenRun((c) -> { + something[0] = something[0] + 10; + async(something[0], c); + }).finish(callback); + }); + } + @Test void testInvalid() { assertThrows(IllegalStateException.class, () -> { From ad85f8be899ad48dabada557cea968a5dbc4b026 Mon Sep 17 00:00:00 2001 From: Maxim Katcharov Date: Fri, 23 Jun 2023 13:16:41 -0600 Subject: [PATCH 11/12] Remove redundant invocation --- .../main/com/mongodb/internal/connection/OidcAuthenticator.java | 1 - 1 file changed, 1 deletion(-) diff --git a/driver-core/src/main/com/mongodb/internal/connection/OidcAuthenticator.java b/driver-core/src/main/com/mongodb/internal/connection/OidcAuthenticator.java index fe86e138e2f..18a54311496 100644 --- a/driver-core/src/main/com/mongodb/internal/connection/OidcAuthenticator.java +++ b/driver-core/src/main/com/mongodb/internal/connection/OidcAuthenticator.java @@ -188,7 +188,6 @@ public void reauthenticate(final InternalConnection connection) { @Override public void reauthenticateAsync(final InternalConnection connection, final SingleResultCallback callback) { - assertTrue(connection.opened()); beginAsync().thenRun(c -> { assertTrue(connection.opened()); authLockAsync(connection, connection.getDescription(), c); From 8e32eed21e9f08b5ef8f206e3d892a49ab933849 Mon Sep 17 00:00:00 2001 From: Maxim Katcharov Date: Mon, 26 Jun 2023 14:53:06 -0600 Subject: [PATCH 12/12] PR fixes --- .../org/bson/codecs/kotlin/DataClassCodec.kt | 4 ++-- .../codecs/kotlin/DataClassCodecProvider.kt | 5 +++-- .../kotlin/DataClassCodecProviderTest.kt | 8 +++---- .../mongodb/internal/async/AsyncRunnable.java | 2 +- .../mongodb/internal/async/AsyncSupplier.java | 2 +- .../internal/async/AsyncFunctionsTest.java | 22 +++++++++++++++++++ 6 files changed, 33 insertions(+), 10 deletions(-) diff --git a/bson-kotlin/src/main/kotlin/org/bson/codecs/kotlin/DataClassCodec.kt b/bson-kotlin/src/main/kotlin/org/bson/codecs/kotlin/DataClassCodec.kt index 2b67e1de0c3..da84033c521 100644 --- a/bson-kotlin/src/main/kotlin/org/bson/codecs/kotlin/DataClassCodec.kt +++ b/bson-kotlin/src/main/kotlin/org/bson/codecs/kotlin/DataClassCodec.kt @@ -139,8 +139,8 @@ internal data class DataClassCodec( validateAnnotations(kClass) val primaryConstructor = kClass.primaryConstructor ?: throw CodecConfigurationException("No primary constructor for $kClass") - val typeMap = - types.mapIndexed { i, k -> primaryConstructor.typeParameters[i].createType() to k }.toMap() + val typeMap = types.mapIndexed { i, k -> primaryConstructor.typeParameters[i].createType() to k } + .toMap() val propertyModels = primaryConstructor.parameters.map { kParameter -> diff --git a/bson-kotlin/src/main/kotlin/org/bson/codecs/kotlin/DataClassCodecProvider.kt b/bson-kotlin/src/main/kotlin/org/bson/codecs/kotlin/DataClassCodecProvider.kt index 962741033e1..e13d74116c7 100644 --- a/bson-kotlin/src/main/kotlin/org/bson/codecs/kotlin/DataClassCodecProvider.kt +++ b/bson-kotlin/src/main/kotlin/org/bson/codecs/kotlin/DataClassCodecProvider.kt @@ -15,14 +15,15 @@ */ package org.bson.codecs.kotlin -import java.lang.reflect.Type import org.bson.codecs.Codec import org.bson.codecs.configuration.CodecProvider import org.bson.codecs.configuration.CodecRegistry +import java.lang.reflect.Type /** A Kotlin reflection based Codec Provider for data classes */ public class DataClassCodecProvider : CodecProvider { - override fun get(clazz: Class, registry: CodecRegistry): Codec? = get(clazz, emptyList(), registry) + override fun get(clazz: Class, registry: CodecRegistry): Codec? = + get(clazz, emptyList(), registry) override fun get(clazz: Class, typeArguments: List, registry: CodecRegistry): Codec? = DataClassCodec.create(clazz.kotlin, registry, typeArguments) diff --git a/bson-kotlin/src/test/kotlin/org/bson/codecs/kotlin/DataClassCodecProviderTest.kt b/bson-kotlin/src/test/kotlin/org/bson/codecs/kotlin/DataClassCodecProviderTest.kt index e0c7f9d1d1b..b36ada9622a 100644 --- a/bson-kotlin/src/test/kotlin/org/bson/codecs/kotlin/DataClassCodecProviderTest.kt +++ b/bson-kotlin/src/test/kotlin/org/bson/codecs/kotlin/DataClassCodecProviderTest.kt @@ -16,16 +16,16 @@ package org.bson.codecs.kotlin import com.mongodb.MongoClientSettings -import kotlin.test.assertEquals -import kotlin.test.assertNotNull -import kotlin.test.assertNull -import kotlin.test.assertTrue import org.bson.codecs.configuration.CodecConfigurationException import org.bson.codecs.kotlin.samples.DataClassParameterized import org.bson.codecs.kotlin.samples.DataClassWithSimpleValues import org.bson.conversions.Bson import org.junit.jupiter.api.Test import org.junit.jupiter.api.assertThrows +import kotlin.test.assertEquals +import kotlin.test.assertNotNull +import kotlin.test.assertNull +import kotlin.test.assertTrue class DataClassCodecProviderTest { diff --git a/driver-core/src/main/com/mongodb/internal/async/AsyncRunnable.java b/driver-core/src/main/com/mongodb/internal/async/AsyncRunnable.java index 8d4ee54d7f5..b9089252f49 100644 --- a/driver-core/src/main/com/mongodb/internal/async/AsyncRunnable.java +++ b/driver-core/src/main/com/mongodb/internal/async/AsyncRunnable.java @@ -73,7 +73,7 @@ default void thenAlwaysRunAndFinish(final Runnable runnable, final SingleResultC callback.onResult(null, t); return; } - callback.onResult(r, e); + callback.onResult(null, e); }); } diff --git a/driver-core/src/main/com/mongodb/internal/async/AsyncSupplier.java b/driver-core/src/main/com/mongodb/internal/async/AsyncSupplier.java index 78575e3fd1a..7b38595bda9 100644 --- a/driver-core/src/main/com/mongodb/internal/async/AsyncSupplier.java +++ b/driver-core/src/main/com/mongodb/internal/async/AsyncSupplier.java @@ -73,7 +73,7 @@ default void finish(final SingleResultCallback callback) { } /** - * @param function The async function to run after this runnable + * @param function The async function to run after this supplier * @return the composition of this supplier and the function, a supplier * @param The return type of the resulting supplier */ diff --git a/driver-core/src/test/unit/com/mongodb/internal/async/AsyncFunctionsTest.java b/driver-core/src/test/unit/com/mongodb/internal/async/AsyncFunctionsTest.java index 009fb6c7dde..375f7b5c555 100644 --- a/driver-core/src/test/unit/com/mongodb/internal/async/AsyncFunctionsTest.java +++ b/driver-core/src/test/unit/com/mongodb/internal/async/AsyncFunctionsTest.java @@ -306,6 +306,7 @@ void testVariationsBranching() { @Test void testErrorIf() { + // thenSupply: assertBehavesSameVariations(5, () -> { try { @@ -325,6 +326,27 @@ void testErrorIf() { asyncReturns(2, c); }).finish(callback); }); + + // thenRun: + assertBehavesSameVariations(5, + () -> { + try { + sync(1); + } catch (Exception e) { + if (e.getMessage().equals(plainTest(1) ? "unexpected" : "exception-1")) { + sync(2); + } else { + throw e; + } + } + }, + (callback) -> { + beginAsync().thenRun(c -> { + async(1, c); + }).onErrorIf(e -> e.getMessage().equals(plainTest(1) ? "unexpected" : "exception-1"), c -> { + async(2, c); + }).finish(callback); + }); } @Test