33
33
import org .bson .BsonDocument ;
34
34
import org .bson .BsonInt32 ;
35
35
import org .bson .BsonString ;
36
+ import org .bson .Document ;
36
37
import org .junit .jupiter .api .AfterEach ;
37
38
import org .junit .jupiter .api .BeforeEach ;
38
39
import org .junit .jupiter .api .Test ;
71
72
import static org .junit .jupiter .api .Assertions .assertFalse ;
72
73
import static org .junit .jupiter .api .Assertions .assertThrows ;
73
74
import static org .junit .jupiter .api .Assertions .assertTrue ;
74
- import static org .junit .jupiter .api .Assertions .fail ;
75
75
import static org .junit .jupiter .api .Assumptions .assumeTrue ;
76
76
import static util .ThreadTestHelpers .executeAll ;
77
77
@@ -201,7 +201,7 @@ public void test2p2RequestCallbackReturnsNull() {
201
201
//noinspection ConstantConditions
202
202
OidcCallback callback = (context ) -> null ;
203
203
MongoClientSettings clientSettings = this .createSettings (callback );
204
- performFind (clientSettings , MongoConfigurationException .class ,
204
+ assertFindFails (clientSettings , MongoConfigurationException .class ,
205
205
"Result of callback must not be null" );
206
206
}
207
207
@@ -216,12 +216,9 @@ public void test2p3CallbackReturnsMissingData() {
216
216
// we ensure that the error is propagated
217
217
MongoClientSettings clientSettings = createSettings (callback );
218
218
try (MongoClient mongoClient = createMongoClient (clientSettings )) {
219
- try {
220
- performFind (mongoClient );
221
- fail ();
222
- } catch (Exception e ) {
223
- assertCause (IllegalArgumentException .class , "accessToken can not be null" , e );
224
- }
219
+ assertCause (IllegalArgumentException .class ,
220
+ "accessToken can not be null" ,
221
+ () -> performFind (mongoClient ));
225
222
}
226
223
}
227
224
@@ -230,13 +227,9 @@ public void test2p4InvalidClientConfigurationWithCallback() {
230
227
String uri = getOidcUri () + "&authMechanismProperties=ENVIRONMENT:" + getOidcEnv ();
231
228
MongoClientSettings settings = createSettings (
232
229
uri , createCallback (), null , OIDC_CALLBACK_KEY );
233
- try {
234
- performFind (settings );
235
- fail ();
236
- } catch (Exception e ) {
237
- assertCause (IllegalArgumentException .class ,
238
- "OIDC_CALLBACK must not be specified when ENVIRONMENT is specified" , e );
239
- }
230
+ assertCause (IllegalArgumentException .class ,
231
+ "OIDC_CALLBACK must not be specified when ENVIRONMENT is specified" ,
232
+ () -> performFind (settings ));
240
233
}
241
234
242
235
@ Test
@@ -282,13 +275,9 @@ public void test3p2AuthFailsWithoutCachedToken() {
282
275
(x ) -> new OidcCallbackResult ("invalid_token" , Duration .ZERO );
283
276
MongoClientSettings clientSettings = createSettings (callback );
284
277
try (MongoClient mongoClient = createMongoClient (clientSettings )) {
285
- try {
286
- performFind (mongoClient );
287
- fail ();
288
- } catch (Exception e ) {
289
- assertCause (MongoCommandException .class ,
290
- "Command failed with error 18 (AuthenticationFailed):" , e );
291
- }
278
+ assertCause (MongoCommandException .class ,
279
+ "Command failed with error 18 (AuthenticationFailed):" ,
280
+ () -> performFind (mongoClient ));
292
281
}
293
282
}
294
283
@@ -321,8 +310,6 @@ public void test3p3UnexpectedErrorDoesNotClearCache() {
321
310
}
322
311
}
323
312
324
- // TODO-OIDC reinstate 2 broken(?) tests in mongodb-oidc-no-retry.json
325
-
326
313
@ Test
327
314
public void test4p1Reauthentication () {
328
315
TestCallback callback = createCallback ();
@@ -335,6 +322,59 @@ public void test4p1Reauthentication() {
335
322
assertEquals (2 , callback .invocations .get ());
336
323
}
337
324
325
+ @ Test
326
+ public void test4p2ReadCommandsFailIfReauthenticationFails () {
327
+ // Create a `MongoClient` whose OIDC callback returns one good token
328
+ // and then bad tokens after the first call.
329
+ TestCallback wrappedCallback = createCallback ();
330
+ OidcCallback callback = (context ) -> {
331
+ OidcCallbackResult result1 = wrappedCallback .callback (context );
332
+ return new OidcCallbackResult (
333
+ wrappedCallback .getInvocations () > 1 ? "bad" : result1 .getAccessToken (),
334
+ Duration .ZERO ,
335
+ null );
336
+ };
337
+ MongoClientSettings clientSettings = createSettings (callback );
338
+ try (MongoClient mongoClient = createMongoClient (clientSettings )) {
339
+ performFind (mongoClient );
340
+ failCommand (391 , 1 , "find" );
341
+ assertCause (MongoCommandException .class ,
342
+ "Command failed with error 18" ,
343
+ () -> performFind (mongoClient ));
344
+ }
345
+ assertEquals (2 , wrappedCallback .invocations .get ());
346
+ }
347
+
348
+ @ Test
349
+ public void test4p3WriteCommandsFailIfReauthenticationFails () {
350
+ // Create a `MongoClient` whose OIDC callback returns one good token
351
+ // and then bad tokens after the first call.
352
+ TestCallback wrappedCallback = createCallback ();
353
+ OidcCallback callback = (context ) -> {
354
+ OidcCallbackResult result1 = wrappedCallback .callback (context );
355
+ return new OidcCallbackResult (
356
+ wrappedCallback .getInvocations () > 1 ? "bad" : result1 .getAccessToken (),
357
+ Duration .ZERO ,
358
+ null );
359
+ };
360
+ MongoClientSettings clientSettings = createSettings (callback );
361
+ try (MongoClient mongoClient = createMongoClient (clientSettings )) {
362
+ performInsert (mongoClient );
363
+ failCommand (391 , 1 , "insert" );
364
+ assertCause (MongoCommandException .class ,
365
+ "Command failed with error 18" ,
366
+ () -> performInsert (mongoClient ));
367
+ }
368
+ assertEquals (2 , wrappedCallback .invocations .get ());
369
+ }
370
+
371
+ private static void performInsert (final MongoClient mongoClient ) {
372
+ mongoClient
373
+ .getDatabase ("test" )
374
+ .getCollection ("test" )
375
+ .insertOne (Document .parse ("{ x: 1 }" ));
376
+ }
377
+
338
378
@ Test
339
379
public void test5p1Azure () {
340
380
assumeTrue (getOidcEnv ().equals ("azure" ));
@@ -410,7 +450,7 @@ public void testh1p5MultiplePrincipalNoUser() {
410
450
// Create an OIDC configured client with `MONGODB_URI_MULTI` and no username.
411
451
MongoClientSettings clientSettings = createSettingsMulti (null , createHumanCallback ());
412
452
// Assert that a `find` operation fails.
413
- performFind (clientSettings , MongoCommandException .class , "Authentication failed" );
453
+ assertFindFails (clientSettings , MongoCommandException .class , "Authentication failed" );
414
454
}
415
455
416
456
@ Test
@@ -420,15 +460,15 @@ public void testh1p6AllowedHostsBlocked() {
420
460
//- Assert that a ``find`` operation fails with a client-side error.
421
461
MongoClientSettings clientSettings1 = createSettings (getOidcUri (),
422
462
createHumanCallback (), null , OIDC_HUMAN_CALLBACK_KEY , Collections .emptyList ());
423
- performFind (clientSettings1 , MongoSecurityException .class , "not permitted by ALLOWED_HOSTS" );
463
+ assertFindFails (clientSettings1 , MongoSecurityException .class , "not permitted by ALLOWED_HOSTS" );
424
464
425
465
//- Create a client that uses the URL
426
466
// ``mongodb://localhost/?authMechanism=MONGODB-OIDC&ignored=example.com``, a
427
467
// human callback, and an ``ALLOWED_HOSTS`` that contains ``["example.com"]``.
428
468
//- Assert that a ``find`` operation fails with a client-side error.
429
469
MongoClientSettings clientSettings2 = createSettings (getOidcUri () + "&ignored=example.com" ,
430
470
createHumanCallback (), null , OIDC_HUMAN_CALLBACK_KEY , Arrays .asList ("example.com" ));
431
- performFind (clientSettings2 , MongoSecurityException .class , "not permitted by ALLOWED_HOSTS" );
471
+ assertFindFails (clientSettings2 , MongoSecurityException .class , "not permitted by ALLOWED_HOSTS" );
432
472
}
433
473
434
474
// Not a prose test
@@ -485,14 +525,14 @@ public void testh2p2HumanCallbackReturnsMissingData() {
485
525
assumeTestEnvironment ();
486
526
//noinspection ConstantConditions
487
527
OidcCallback callbackNull = (context ) -> null ;
488
- performFind (createHumanSettings (callbackNull , null ),
528
+ assertFindFails (createHumanSettings (callbackNull , null ),
489
529
MongoConfigurationException .class ,
490
530
"Result of callback must not be null" );
491
531
492
532
//noinspection ConstantConditions
493
533
OidcCallback callback =
494
534
(context ) -> new OidcCallbackResult (null , Duration .ZERO );
495
- performFind (createHumanSettings (callback , null ),
535
+ assertFindFails (createHumanSettings (callback , null ),
496
536
IllegalArgumentException .class ,
497
537
"accessToken can not be null" );
498
538
}
@@ -503,7 +543,7 @@ public void testRefreshTokenAbsent() {
503
543
// additionally, check validation for refresh in machine workflow:
504
544
OidcCallback callbackMachineRefresh =
505
545
(context ) -> new OidcCallbackResult ("access" , Duration .ZERO , "exists" );
506
- performFind (createSettings (callbackMachineRefresh ),
546
+ assertFindFails (createSettings (callbackMachineRefresh ),
507
547
MongoConfigurationException .class ,
508
548
"Refresh token must only be provided in human workflow" );
509
549
}
@@ -549,7 +589,7 @@ public void testh3p2NoSpecAuthIfNoCachedToken() {
549
589
failCommand (18 , 1 , "saslStart" );
550
590
TestListener listener = new TestListener ();
551
591
TestCommandListener commandListener = new TestCommandListener (listener );
552
- performFind (createHumanSettings (createHumanCallback (), commandListener ),
592
+ assertFindFails (createHumanSettings (createHumanCallback (), commandListener ),
553
593
MongoCommandException .class ,
554
594
"Command failed with error 18" );
555
595
assertEquals (Arrays .asList (
@@ -833,7 +873,7 @@ private void performFind(final MongoClientSettings settings) {
833
873
}
834
874
}
835
875
836
- private <T extends Throwable > void performFind (
876
+ private <T extends Throwable > void assertFindFails (
837
877
final MongoClientSettings settings ,
838
878
final Class <T > expectedExceptionOrCause ,
839
879
final String expectedMessage ) {
@@ -852,21 +892,15 @@ private void performFind(final MongoClient mongoClient) {
852
892
853
893
private static <T extends Throwable > void assertCause (
854
894
final Class <T > expectedCause , final String expectedMessageFragment , final Executable e ) {
855
- Throwable actualException = assertThrows (Throwable .class , e );
856
- assertCause (expectedCause , expectedMessageFragment , actualException );
857
- }
858
-
859
- private static <T extends Throwable > void assertCause (
860
- final Class <T > expectedCause , final String expectedMessageFragment , final Throwable actualException ) {
861
- Throwable cause = actualException ;
895
+ Throwable cause = assertThrows (Throwable .class , e );
862
896
while (cause .getCause () != null ) {
863
897
cause = cause .getCause ();
864
898
}
865
899
if (!expectedCause .isInstance (cause )) {
866
- throw new AssertionFailedError ("Unexpected cause: " + actualException . getClass (), actualException );
900
+ throw new AssertionFailedError ("Unexpected cause: " + assertThrows ( Throwable . class , e ). getClass (), assertThrows ( Throwable . class , e ) );
867
901
}
868
902
if (!cause .getMessage ().contains (expectedMessageFragment )) {
869
- throw new AssertionFailedError ("Unexpected message" , actualException );
903
+ throw new AssertionFailedError ("Unexpected message" , assertThrows ( Throwable . class , e ) );
870
904
}
871
905
}
872
906
0 commit comments