Skip to content

Commit 9873b4a

Browse files
committed
Implement OIDC auth for async
JAVA-4981
1 parent b03bfbc commit 9873b4a

File tree

11 files changed

+1129
-17
lines changed

11 files changed

+1129
-17
lines changed

driver-core/src/main/com/mongodb/internal/Locks.java

Lines changed: 27 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@
1717
package com.mongodb.internal;
1818

1919
import com.mongodb.MongoInterruptedException;
20+
import com.mongodb.internal.async.AsyncRunnable;
21+
import com.mongodb.internal.async.SingleResultCallback;
2022

2123
import java.util.concurrent.locks.Lock;
2224
import java.util.concurrent.locks.StampedLock;
@@ -33,6 +35,26 @@ public static void withLock(final Lock lock, final Runnable action) {
3335
});
3436
}
3537

38+
public static void withLockAsync(final StampedLock lock, final AsyncRunnable runnable,
39+
final SingleResultCallback<Void> callback) {
40+
long stamp;
41+
try {
42+
stamp = lock.writeLockInterruptibly();
43+
} catch (InterruptedException e) {
44+
Thread.currentThread().interrupt();
45+
try {
46+
throw new MongoInterruptedException("Interrupted waiting for lock", e);
47+
} catch (MongoInterruptedException mie) {
48+
callback.onResult(null, mie);
49+
return;
50+
}
51+
}
52+
53+
runnable.completeAlways(() -> {
54+
lock.unlockWrite(stamp);
55+
}, callback);
56+
}
57+
3658
public static <V> V withLock(final StampedLock lock, final Supplier<V> supplier) {
3759
long stamp;
3860
try {
@@ -55,15 +77,15 @@ public static <V> V withLock(final Lock lock, final Supplier<V> supplier) {
5577
public static <V, E extends Exception> V checkedWithLock(final Lock lock, final CheckedSupplier<V, E> supplier) throws E {
5678
try {
5779
lock.lockInterruptibly();
58-
try {
59-
return supplier.get();
60-
} finally {
61-
lock.unlock();
62-
}
6380
} catch (InterruptedException e) {
6481
Thread.currentThread().interrupt();
6582
throw new MongoInterruptedException("Interrupted waiting for lock", e);
6683
}
84+
try {
85+
return supplier.get();
86+
} finally {
87+
lock.unlock();
88+
}
6789
}
6890

6991
private Locks() {
Lines changed: 165 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,165 @@
1+
/*
2+
* Copyright 2008-present MongoDB, Inc.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package com.mongodb.internal.async;
18+
19+
import com.mongodb.internal.async.function.RetryState;
20+
import com.mongodb.internal.async.function.RetryingAsyncCallbackSupplier;
21+
22+
import java.util.function.Function;
23+
import java.util.function.Predicate;
24+
25+
/**
26+
* See AsyncRunnableTest for usage
27+
*/
28+
public interface AsyncRunnable {
29+
30+
static AsyncRunnable startAsync() {
31+
return (c) -> c.onResult(null, null);
32+
}
33+
34+
/**
35+
* Must be invoked at end of async chain
36+
* @param callback the callback provided by the method the chain is used in
37+
*/
38+
void complete(SingleResultCallback<Void> callback); // NoResultCallback
39+
40+
/**
41+
* Must be invoked at end of async chain
42+
* @param runnable the sync code to invoke (under non-exceptional flow)
43+
* prior to the callback
44+
* @param callback the callback provided by the method the chain is used in
45+
*/
46+
default void complete(final Runnable runnable, final SingleResultCallback<Void> callback) {
47+
this.complete((r, e) -> {
48+
if (e != null) {
49+
callback.onResult(null, e);
50+
return;
51+
}
52+
try {
53+
runnable.run();
54+
} catch (Throwable t) {
55+
callback.onResult(null, t);
56+
return;
57+
}
58+
callback.onResult(null, null);
59+
});
60+
}
61+
62+
/**
63+
* See {@link #complete(Runnable, SingleResultCallback)}, but the runnable
64+
* will always be executed, including on the exceptional path.
65+
* @param runnable the runnable
66+
* @param callback the callback
67+
*/
68+
default void completeAlways(final Runnable runnable, final SingleResultCallback<Void> callback) {
69+
this.complete((r, e) -> {
70+
try {
71+
runnable.run();
72+
} catch (Throwable t) {
73+
callback.onResult(null, t);
74+
return;
75+
}
76+
callback.onResult(r, e);
77+
});
78+
}
79+
80+
/**
81+
* @param runnable The async runnable to run after this one
82+
* @return the composition of this and the runnable
83+
*/
84+
default AsyncRunnable run(final AsyncRunnable runnable) {
85+
return (c) -> {
86+
this.complete((r, e) -> {
87+
if (e != null) {
88+
c.onResult(null, e);
89+
return;
90+
}
91+
try {
92+
runnable.complete(c);
93+
} catch (Throwable t) {
94+
c.onResult(null, t);
95+
}
96+
});
97+
};
98+
}
99+
100+
/**
101+
* @param supplier The supplier to supply using after this runnable.
102+
* @return the composition of this runnable and the supplier
103+
* @param <T> The return type of the supplier
104+
*/
105+
default <T> AsyncSupplier<T> supply(final AsyncSupplier<T> supplier) {
106+
return (c) -> {
107+
this.complete((r, e) -> {
108+
if (e != null) {
109+
c.onResult(null, e);
110+
return;
111+
}
112+
try {
113+
supplier.complete(c);
114+
} catch (Throwable t) {
115+
c.onResult(null, t);
116+
}
117+
});
118+
};
119+
}
120+
121+
/**
122+
* @param errorCheck A check, comparable to a catch-if/otherwise-rethrow
123+
* @param runnable The branch to execute if the error matches
124+
* @return The composition of this, and the conditional branch
125+
*/
126+
default AsyncRunnable onErrorIf(
127+
final Function<Throwable, Boolean> errorCheck,
128+
final AsyncRunnable runnable) {
129+
return (callback) -> this.complete((r, e) -> {
130+
if (e == null) {
131+
callback.onResult(r, null);
132+
return;
133+
}
134+
try {
135+
Boolean check = errorCheck.apply(e);
136+
if (check) {
137+
runnable.complete(callback);
138+
return;
139+
}
140+
} catch (Throwable t) {
141+
callback.onResult(null, t);
142+
return;
143+
}
144+
callback.onResult(r, e);
145+
});
146+
}
147+
148+
/**
149+
* @see RetryingAsyncCallbackSupplier
150+
* @param shouldRetry condition under which to retry
151+
* @param runnable the runnable to loop
152+
* @return the composition of this, and the looping branch
153+
*/
154+
default AsyncRunnable runRetryingWhen(
155+
final Predicate<Throwable> shouldRetry,
156+
final AsyncRunnable runnable) {
157+
return this.run(callback -> {
158+
new RetryingAsyncCallbackSupplier<Void>(
159+
new RetryState(),
160+
(rs, lastAttemptFailure) -> shouldRetry.test(lastAttemptFailure),
161+
cb -> runnable.complete(cb)
162+
).get(callback);
163+
});
164+
}
165+
}
Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
/*
2+
* Copyright 2008-present MongoDB, Inc.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package com.mongodb.internal.async;
18+
19+
import java.util.function.Function;
20+
21+
/**
22+
* See AsyncRunnableTest for usage
23+
*/
24+
public interface AsyncSupplier<T> {
25+
26+
/**
27+
* Must be invoked at end of async chain
28+
* @param callback the callback provided by the method the chain is used in
29+
*/
30+
void complete(SingleResultCallback<T> callback);
31+
32+
/**
33+
* @see AsyncRunnable#onErrorIf(Function, AsyncRunnable).
34+
*
35+
* @param errorCheck A check, comparable to a catch-if/otherwise-rethrow
36+
* @param supplier The branch to execute if the error matches
37+
* @return The composition of this, and the conditional branch
38+
*/
39+
default AsyncSupplier<T> onErrorIf(
40+
final Function<Throwable, Boolean> errorCheck,
41+
final AsyncSupplier<T> supplier) {
42+
return (callback) -> this.complete((r, e) -> {
43+
if (e == null) {
44+
callback.onResult(r, null);
45+
return;
46+
}
47+
try {
48+
Boolean check = errorCheck.apply(e);
49+
if (check) {
50+
supplier.complete(callback);
51+
return;
52+
}
53+
} catch (Throwable t) {
54+
callback.onResult(null, t);
55+
return;
56+
}
57+
callback.onResult(r, e);
58+
});
59+
}
60+
}

driver-core/src/main/com/mongodb/internal/async/function/RetryingAsyncCallbackSupplier.java

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,16 @@ public RetryingAsyncCallbackSupplier(
8484
this.asyncFunction = asyncFunction;
8585
}
8686

87+
public RetryingAsyncCallbackSupplier(
88+
final RetryState state,
89+
final BiPredicate<RetryState, Throwable> retryPredicate,
90+
final AsyncCallbackSupplier<R> asyncFunction) {
91+
this.state = state;
92+
this.retryPredicate = retryPredicate;
93+
this.failedResultTransformer = (previouslyChosenFailure, lastAttemptFailure) -> lastAttemptFailure;
94+
this.asyncFunction = asyncFunction;
95+
}
96+
8797
@Override
8898
public void get(final SingleResultCallback<R> callback) {
8999
/* `asyncFunction` and `callback` are the only externally provided pieces of code for which we do not need to care about

driver-core/src/main/com/mongodb/internal/connection/Authenticator.java

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -104,4 +104,9 @@ public void reauthenticate(final InternalConnection connection) {
104104
authenticate(connection, connection.getDescription());
105105
}
106106

107+
public void reauthenticateAsync(final InternalConnection connection, final SingleResultCallback<Void> callback) {
108+
throw new UnsupportedOperationException(
109+
"Reauthentication requested by server but is not supported by specified mechanism.");
110+
}
111+
107112
}

driver-core/src/main/com/mongodb/internal/connection/InternalStreamConnection.java

Lines changed: 28 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@
4444
import com.mongodb.connection.StreamFactory;
4545
import com.mongodb.event.CommandListener;
4646
import com.mongodb.internal.VisibleForTesting;
47+
import com.mongodb.internal.async.AsyncSupplier;
4748
import com.mongodb.internal.async.SingleResultCallback;
4849
import com.mongodb.internal.diagnostics.logging.Logger;
4950
import com.mongodb.internal.diagnostics.logging.Loggers;
@@ -74,6 +75,7 @@
7475
import static com.mongodb.assertions.Assertions.assertNotNull;
7576
import static com.mongodb.assertions.Assertions.isTrue;
7677
import static com.mongodb.assertions.Assertions.notNull;
78+
import static com.mongodb.internal.async.AsyncRunnable.startAsync;
7779
import static com.mongodb.internal.async.ErrorHandlingResultCallback.errorHandlingCallback;
7880
import static com.mongodb.internal.connection.CommandHelper.HELLO;
7981
import static com.mongodb.internal.connection.CommandHelper.LEGACY_HELLO;
@@ -390,6 +392,31 @@ public <T> T sendAndReceive(final CommandMessage message, final Decoder<T> decod
390392
}
391393
}
392394

395+
396+
@Override
397+
public <T> void sendAndReceiveAsync(final CommandMessage message, final Decoder<T> decoder, final SessionContext sessionContext,
398+
final RequestContext requestContext, final OperationContext operationContext, final SingleResultCallback<T> callback) {
399+
notNull("stream is open", stream, callback);
400+
401+
AsyncSupplier<T> sendAndReceiveAsyncInternal = c -> sendAndReceiveAsyncInternal(
402+
message, decoder, sessionContext, requestContext, operationContext, c);
403+
404+
if (!Authenticator.shouldAuthenticate(authenticator, this.description)) {
405+
sendAndReceiveAsyncInternal.complete(callback);
406+
return;
407+
}
408+
409+
sendAndReceiveAsyncInternal.onErrorIf(e -> triggersReauthentication(e), startAsync()
410+
.run(c -> {
411+
authenticated.set(false);
412+
authenticator.reauthenticateAsync(this, c);
413+
}).supply((c) -> {
414+
authenticated.set(true);
415+
sendAndReceiveAsyncInternal.complete(c);
416+
}))
417+
.complete(callback);
418+
}
419+
393420
public static boolean triggersReauthentication(@Nullable final Throwable t) {
394421
if (t instanceof MongoCommandException) {
395422
MongoCommandException e = (MongoCommandException) t;
@@ -518,11 +545,8 @@ private <T> T receiveCommandMessageResponse(final Decoder<T> decoder,
518545
}
519546
}
520547

521-
@Override
522-
public <T> void sendAndReceiveAsync(final CommandMessage message, final Decoder<T> decoder, final SessionContext sessionContext,
548+
private <T> void sendAndReceiveAsyncInternal(final CommandMessage message, final Decoder<T> decoder, final SessionContext sessionContext,
523549
final RequestContext requestContext, final OperationContext operationContext, final SingleResultCallback<T> callback) {
524-
notNull("stream is open", stream, callback);
525-
526550
if (isClosed()) {
527551
callback.onResult(null, new MongoSocketClosedException("Can not read from a closed socket", getServerAddress()));
528552
return;

0 commit comments

Comments
 (0)