diff --git a/src/client/auth.rs b/src/client/auth.rs index d3678fffe..c20fc1d6c 100644 --- a/src/client/auth.rs +++ b/src/client/auth.rs @@ -341,6 +341,44 @@ impl AuthMechanism { .into()), } } + + pub(crate) async fn reauthenticate_stream( + &self, + stream: &mut Connection, + credential: &Credential, + server_api: Option<&ServerApi>, + ) -> Result<()> { + self.validate_credential(credential)?; + + match self { + AuthMechanism::ScramSha1 + | AuthMechanism::ScramSha256 + | AuthMechanism::MongoDbX509 + | AuthMechanism::Plain + | AuthMechanism::MongoDbCr => Err(ErrorKind::Authentication { + message: format!( + "Reauthentication for authentication mechanism {:?} is not supported.", + self + ), + } + .into()), + #[cfg(feature = "aws-auth")] + AuthMechanism::MongoDbAws => Err(ErrorKind::Authentication { + message: format!( + "Reauthentication for authentication mechanism {:?} is not supported.", + self + ), + } + .into()), + AuthMechanism::MongoDbOidc => { + oidc::reauthenticate_stream(stream, credential, server_api).await + } + _ => Err(ErrorKind::Authentication { + message: format!("Authentication mechanism {:?} not yet implemented.", self), + } + .into()), + } + } } impl FromStr for AuthMechanism { diff --git a/src/client/auth/oidc.rs b/src/client/auth/oidc.rs index 11cb01a43..62565d1d0 100644 --- a/src/client/auth/oidc.rs +++ b/src/client/auth/oidc.rs @@ -224,6 +224,15 @@ fn get_refresh_token_and_idp_info( (refresh_token, idp_info) } +pub(crate) async fn reauthenticate_stream( + conn: &mut Connection, + credential: &Credential, + server_api: Option<&ServerApi>, +) -> Result<()> { + invalidate_caches(conn, credential); + authenticate_stream(conn, credential, server_api, None).await +} + pub(crate) async fn authenticate_stream( conn: &mut Connection, credential: &Credential, diff --git a/src/client/executor.rs b/src/client/executor.rs index 0a3b0c249..c3256c550 100644 --- a/src/client/executor.rs +++ b/src/client/executor.rs @@ -280,7 +280,8 @@ impl Client { } /// Selects a server and executes the given operation on it, optionally using a provided - /// session. Retries the operation upon failure if retryability is supported. + /// session. Retries the operation upon failure if retryability is supported or after + /// reauthenticating if reauthentication is required. async fn execute_operation_with_retry( &self, mut op: T, @@ -397,6 +398,30 @@ impl Client { implicit_session, }, Err(mut err) => { + // If the error is a reauthentication required error, we reauthenticate and + // retry the operation. + if err.is_reauthentication_required() { + let credential = self.inner.options.credential.as_ref().ok_or( + ErrorKind::Authentication { + message: "No Credential when reauthentication required error \ + occured" + .to_string(), + }, + )?; + let server_api = self.inner.options.server_api.as_ref(); + + credential + .mechanism + .as_ref() + .ok_or(ErrorKind::Authentication { + message: "No AuthMechanism when reauthentication required error \ + occured" + .to_string(), + })? + .reauthenticate_stream(&mut conn, credential, server_api) + .await?; + continue; + } err.wire_version = conn.stream_description()?.max_wire_version; // Retryable writes are only supported by storage engines with document-level diff --git a/src/error.rs b/src/error.rs index b1dda0098..19e4ceb2d 100644 --- a/src/error.rs +++ b/src/error.rs @@ -26,6 +26,7 @@ const RETRYABLE_WRITE_CODES: [i32; 12] = [ 11600, 11602, 10107, 13435, 13436, 189, 91, 7, 6, 89, 9001, 262, ]; const UNKNOWN_TRANSACTION_COMMIT_RESULT_LABEL_CODES: [i32; 3] = [50, 64, 91]; +const REAUTHENTICATION_REQUIRED_CODE: i32 = 391; /// Retryable write error label. This label will be added to an error when the error is /// write-retryable. @@ -378,6 +379,11 @@ impl Error { .unwrap_or(false) } + /// If this error corresponds to a "reauthentication required" error. + pub(crate) fn is_reauthentication_required(&self) -> bool { + self.sdam_code() == Some(REAUTHENTICATION_REQUIRED_CODE) + } + /// If this error corresponds to a "node is recovering" error as per the SDAM spec. pub(crate) fn is_recovering(&self) -> bool { self.sdam_code() diff --git a/src/test/spec/oidc.rs b/src/test/spec/oidc.rs index d886264ab..bfb55a6f8 100644 --- a/src/test/spec/oidc.rs +++ b/src/test/spec/oidc.rs @@ -1,14 +1,17 @@ +use crate::{ + client::{ + auth::{oidc, AuthMechanism, Credential}, + options::ClientOptions, + }, + test::log_uncaptured, + Client, +}; +use std::sync::{Arc, Mutex}; + +// Machine Callback tests // Prose test 1.1 Single Principal Implicit Username #[tokio::test] -async fn single_principal_implicit_username() -> anyhow::Result<()> { - use crate::{ - client::{ - auth::{oidc, AuthMechanism, Credential}, - options::ClientOptions, - }, - test::log_uncaptured, - Client, - }; +async fn machine_single_principal_implicit_username() -> anyhow::Result<()> { use bson::Document; use futures_util::FutureExt; @@ -17,10 +20,16 @@ async fn single_principal_implicit_username() -> anyhow::Result<()> { return Ok(()); } + // we need to assert that the callback is only called once + let call_count = Arc::new(Mutex::new(0)); + let cb_call_count = call_count.clone(); + let mut opts = ClientOptions::parse("mongodb://localhost/?authMechanism=MONGODB-OIDC").await?; opts.credential = Credential::builder() .mechanism(AuthMechanism::MongoDbOidc) - .oidc_callback(oidc::Callback::machine(|_| { + .oidc_callback(oidc::Callback::machine(move |_| { + let call_count = cb_call_count.clone(); + *call_count.lock().unwrap() += 1; async move { Ok(oidc::IdpServerResponse { access_token: tokio::fs::read_to_string("/tmp/tokens/test_user1").await?, @@ -38,14 +47,14 @@ async fn single_principal_implicit_username() -> anyhow::Result<()> { .collection::("test") .find_one(None, None) .await?; + assert_eq!(1, *(*call_count).lock().unwrap()); Ok(()) } -// TODO RUST-1497: The following test will be removed because it is not an actual test in the spec, -// but just showing that the human flow is still working for two_step (nothing in caching is -// correctly exercised here) +// Human Callback tests +// Prose test 1.1 Single Principal Implicit Username #[tokio::test] -async fn human_flow() -> anyhow::Result<()> { +async fn human_single_principal_implicit_username() -> anyhow::Result<()> { use crate::{ client::{ auth::{oidc, AuthMechanism, Credential}, @@ -62,10 +71,16 @@ async fn human_flow() -> anyhow::Result<()> { return Ok(()); } + // we need to assert that the callback is only called once + let call_count = Arc::new(Mutex::new(0)); + let cb_call_count = call_count.clone(); + let mut opts = ClientOptions::parse("mongodb://localhost/?authMechanism=MONGODB-OIDC").await?; opts.credential = Credential::builder() .mechanism(AuthMechanism::MongoDbOidc) - .oidc_callback(oidc::Callback::human(|_| { + .oidc_callback(oidc::Callback::human(move |_| { + let call_count = cb_call_count.clone(); + *call_count.lock().unwrap() += 1; async move { Ok(oidc::IdpServerResponse { access_token: tokio::fs::read_to_string("/tmp/tokens/test_user1").await?, @@ -83,5 +98,6 @@ async fn human_flow() -> anyhow::Result<()> { .collection::("test") .find_one(None, None) .await?; + assert_eq!(1, *(*call_count).lock().unwrap()); Ok(()) }