Skip to content

Commit 2dadbed

Browse files
RUST-1894 Retry KMS requests on transient errors (#1281)
1 parent fac8592 commit 2dadbed

File tree

6 files changed

+223
-33
lines changed

6 files changed

+223
-33
lines changed

.config/nextest.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[profile.default]
22
test-threads = 1
3-
default-filter = 'not test(test::happy_eyeballs)'
3+
default-filter = 'not test(test::happy_eyeballs) and not test(kms_retry)'
44

55
[profile.ci]
66
failure-output = "final"

Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -176,6 +176,7 @@ time = "0.3.9"
176176
tokio = { version = ">= 0.0.0", features = ["fs", "parking_lot"] }
177177
tracing-subscriber = "0.3.16"
178178
regex = "1.6.0"
179+
reqwest = { version = "0.12.2", features = ["rustls-tls"] }
179180
serde-hex = "0.1.0"
180181
serde_path_to_error = "0.1"
181182

src/client/csfle.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,7 @@ impl ClientState {
9999
let mut builder = Crypt::builder()
100100
.kms_providers(&opts.kms_providers.credentials_doc()?)?
101101
.use_need_kms_credentials_state()
102+
.retry_kms(true)?
102103
.use_range_v2()?;
103104
if let Some(m) = &opts.schema_map {
104105
builder = builder.schema_map(&bson::to_document(m)?)?;

src/client/csfle/client_encryption.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,7 @@ impl ClientEncryption {
6565
let crypt = Crypt::builder()
6666
.kms_providers(&kms_providers.credentials_doc()?)?
6767
.use_need_kms_credentials_state()
68+
.retry_kms(true)?
6869
.use_range_v2()?
6970
.build()?;
7071
let exec = CryptExecutor::new_explicit(

src/client/csfle/state_machine.rs

Lines changed: 58 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -2,19 +2,20 @@ use std::{
22
convert::TryInto,
33
ops::DerefMut,
44
path::{Path, PathBuf},
5+
time::Duration,
56
};
67

78
use bson::{rawdoc, Document, RawDocument, RawDocumentBuf};
89
use futures_util::{stream, TryStreamExt};
9-
use mongocrypt::ctx::{Ctx, KmsProviderType, State};
10+
use mongocrypt::ctx::{Ctx, KmsCtx, KmsProviderType, State};
1011
use rayon::ThreadPool;
1112
use tokio::{
1213
io::{AsyncReadExt, AsyncWriteExt},
1314
sync::{oneshot, Mutex},
1415
};
1516

1617
use crate::{
17-
client::{options::ServerAddress, WeakClient},
18+
client::{csfle::options::KmsProvidersTlsOptions, options::ServerAddress, WeakClient},
1819
error::{Error, Result},
1920
operation::{run_command::RunCommand, RawOutput},
2021
options::ReadConcern,
@@ -174,37 +175,62 @@ impl CryptExecutor {
174175
State::NeedKms => {
175176
let ctx = result_mut(&mut ctx)?;
176177
let scope = ctx.kms_scope();
177-
let mut kms_ctxen: Vec<Result<_>> = vec![];
178-
while let Some(kms_ctx) = scope.next_kms_ctx() {
179-
kms_ctxen.push(Ok(kms_ctx));
178+
179+
async fn execute(
180+
kms_ctx: &mut KmsCtx<'_>,
181+
tls_options: Option<&KmsProvidersTlsOptions>,
182+
) -> Result<()> {
183+
let endpoint = kms_ctx.endpoint()?;
184+
let addr = ServerAddress::parse(endpoint)?;
185+
let provider = kms_ctx.kms_provider()?;
186+
let tls_options = tls_options
187+
.and_then(|tls| tls.get(&provider))
188+
.cloned()
189+
.unwrap_or_default();
190+
let mut stream =
191+
AsyncStream::connect(addr, Some(&TlsConfig::new(tls_options)?)).await?;
192+
stream.write_all(kms_ctx.message()?).await?;
193+
let mut buf = vec![0];
194+
while kms_ctx.bytes_needed() > 0 {
195+
let buf_size = kms_ctx.bytes_needed().try_into().map_err(|e| {
196+
Error::internal(format!("buffer size overflow: {}", e))
197+
})?;
198+
buf.resize(buf_size, 0);
199+
let count = stream.read(&mut buf).await?;
200+
kms_ctx.feed(&buf[0..count])?;
201+
}
202+
Ok(())
203+
}
204+
205+
loop {
206+
let mut kms_contexts: Vec<Result<_>> = Vec::new();
207+
while let Some(kms_ctx) = scope.next_kms_ctx() {
208+
kms_contexts.push(Ok(kms_ctx));
209+
}
210+
if kms_contexts.is_empty() {
211+
break;
212+
}
213+
214+
stream::iter(kms_contexts)
215+
.try_for_each_concurrent(None, |mut kms_ctx| async move {
216+
let sleep_micros =
217+
u64::try_from(kms_ctx.sleep_micros()).unwrap_or(0);
218+
if sleep_micros > 0 {
219+
tokio::time::sleep(Duration::from_micros(sleep_micros)).await;
220+
}
221+
222+
if let Err(error) =
223+
execute(&mut kms_ctx, self.kms_providers.tls_options()).await
224+
{
225+
if !kms_ctx.retry_failure() {
226+
return Err(error);
227+
}
228+
}
229+
230+
Ok(())
231+
})
232+
.await?;
180233
}
181-
stream::iter(kms_ctxen)
182-
.try_for_each_concurrent(None, |mut kms_ctx| async move {
183-
let endpoint = kms_ctx.endpoint()?;
184-
let addr = ServerAddress::parse(endpoint)?;
185-
let provider = kms_ctx.kms_provider()?;
186-
let tls_options = self
187-
.kms_providers
188-
.tls_options()
189-
.and_then(|tls| tls.get(&provider))
190-
.cloned()
191-
.unwrap_or_default();
192-
let mut stream =
193-
AsyncStream::connect(addr, Some(&TlsConfig::new(tls_options)?))
194-
.await?;
195-
stream.write_all(kms_ctx.message()?).await?;
196-
let mut buf = vec![0];
197-
while kms_ctx.bytes_needed() > 0 {
198-
let buf_size = kms_ctx.bytes_needed().try_into().map_err(|e| {
199-
Error::internal(format!("buffer size overflow: {}", e))
200-
})?;
201-
buf.resize(buf_size, 0);
202-
let count = stream.read(&mut buf).await?;
203-
kms_ctx.feed(&buf[0..count])?;
204-
}
205-
Ok(())
206-
})
207-
.await?;
208234
}
209235
State::NeedKmsCredentials => {
210236
let ctx = result_mut(&mut ctx)?;

src/test/csfle.rs

Lines changed: 161 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3493,6 +3493,167 @@ async fn range_explicit_encryption_defaults() -> Result<()> {
34933493
Ok(())
34943494
}
34953495

3496+
// Prose Test 24. KMS Retry Tests
3497+
#[tokio::test]
3498+
// using openssl causes errors after configuring a network failpoint
3499+
#[cfg(not(feature = "openssl-tls"))]
3500+
async fn kms_retry() {
3501+
use reqwest::{Certificate, Client as HttpClient};
3502+
3503+
let endpoint = "127.0.0.1:9003";
3504+
3505+
let mut certificate_file_path = PathBuf::from(std::env::var("CSFLE_TLS_CERT_DIR").unwrap());
3506+
certificate_file_path.push("ca.pem");
3507+
let certificate_file = std::fs::read(&certificate_file_path).unwrap();
3508+
3509+
let set_failpoint = |kind: &str, count: u8| {
3510+
// create a fresh client for each request to avoid hangs
3511+
let http_client = HttpClient::builder()
3512+
.add_root_certificate(Certificate::from_pem(&certificate_file).unwrap())
3513+
.build()
3514+
.unwrap();
3515+
let url = format!("https://localhost:9003/set_failpoint/{}", kind);
3516+
let body = format!("{{\"count\":{}}}", count);
3517+
http_client.post(url).body(body).send()
3518+
};
3519+
3520+
let aws_kms = AWS_KMS.clone();
3521+
let mut azure_kms = AZURE_KMS.clone();
3522+
azure_kms.1.insert("identityPlatformEndpoint", endpoint);
3523+
let mut gcp_kms = GCP_KMS.clone();
3524+
gcp_kms.1.insert("endpoint", endpoint);
3525+
let mut kms_providers = vec![aws_kms, azure_kms, gcp_kms];
3526+
3527+
let tls_options = get_client_options().await.tls_options();
3528+
for kms_provider in kms_providers.iter_mut() {
3529+
kms_provider.2 = tls_options.clone();
3530+
}
3531+
3532+
let key_vault_client = Client::for_test().await.into_client();
3533+
let client_encryption = ClientEncryption::new(
3534+
key_vault_client,
3535+
Namespace::new("keyvault", "datakeys"),
3536+
kms_providers,
3537+
)
3538+
.unwrap();
3539+
3540+
let aws_master_key = AwsMasterKey::builder()
3541+
.region("foo")
3542+
.key("bar")
3543+
.endpoint(endpoint.to_string())
3544+
.build();
3545+
let azure_master_key = AzureMasterKey::builder()
3546+
.key_vault_endpoint(endpoint)
3547+
.key_name("foo")
3548+
.build();
3549+
let gcp_master_key = GcpMasterKey::builder()
3550+
.project_id("foo")
3551+
.location("bar")
3552+
.key_ring("baz")
3553+
.key_name("qux")
3554+
.endpoint(endpoint.to_string())
3555+
.build();
3556+
3557+
// Case 1: createDataKey and encrypt with TCP retry
3558+
3559+
// AWS
3560+
set_failpoint("network", 1).await.unwrap();
3561+
let key_id = client_encryption
3562+
.create_data_key(aws_master_key.clone())
3563+
.await
3564+
.unwrap();
3565+
set_failpoint("network", 1).await.unwrap();
3566+
client_encryption
3567+
.encrypt(123, key_id, Algorithm::Deterministic)
3568+
.await
3569+
.unwrap();
3570+
3571+
// Azure
3572+
set_failpoint("network", 1).await.unwrap();
3573+
let key_id = client_encryption
3574+
.create_data_key(azure_master_key.clone())
3575+
.await
3576+
.unwrap();
3577+
set_failpoint("network", 1).await.unwrap();
3578+
client_encryption
3579+
.encrypt(123, key_id, Algorithm::Deterministic)
3580+
.await
3581+
.unwrap();
3582+
3583+
// GCP
3584+
set_failpoint("network", 1).await.unwrap();
3585+
let key_id = client_encryption
3586+
.create_data_key(gcp_master_key.clone())
3587+
.await
3588+
.unwrap();
3589+
set_failpoint("network", 1).await.unwrap();
3590+
client_encryption
3591+
.encrypt(123, key_id, Algorithm::Deterministic)
3592+
.await
3593+
.unwrap();
3594+
3595+
// Case 2: createDataKey and encrypt with HTTP retry
3596+
3597+
// AWS
3598+
set_failpoint("http", 1).await.unwrap();
3599+
let key_id = client_encryption
3600+
.create_data_key(aws_master_key.clone())
3601+
.await
3602+
.unwrap();
3603+
set_failpoint("http", 1).await.unwrap();
3604+
client_encryption
3605+
.encrypt(123, key_id, Algorithm::Deterministic)
3606+
.await
3607+
.unwrap();
3608+
3609+
// Azure
3610+
set_failpoint("http", 1).await.unwrap();
3611+
let key_id = client_encryption
3612+
.create_data_key(azure_master_key.clone())
3613+
.await
3614+
.unwrap();
3615+
set_failpoint("http", 1).await.unwrap();
3616+
client_encryption
3617+
.encrypt(123, key_id, Algorithm::Deterministic)
3618+
.await
3619+
.unwrap();
3620+
3621+
// GCP
3622+
set_failpoint("http", 1).await.unwrap();
3623+
let key_id = client_encryption
3624+
.create_data_key(gcp_master_key.clone())
3625+
.await
3626+
.unwrap();
3627+
set_failpoint("http", 1).await.unwrap();
3628+
client_encryption
3629+
.encrypt(123, key_id, Algorithm::Deterministic)
3630+
.await
3631+
.unwrap();
3632+
3633+
// Case 3: createDataKey fails after too many retries
3634+
3635+
// AWS
3636+
set_failpoint("network", 4).await.unwrap();
3637+
client_encryption
3638+
.create_data_key(aws_master_key)
3639+
.await
3640+
.unwrap_err();
3641+
3642+
// Azure
3643+
set_failpoint("network", 4).await.unwrap();
3644+
client_encryption
3645+
.create_data_key(azure_master_key)
3646+
.await
3647+
.unwrap_err();
3648+
3649+
// GCP
3650+
set_failpoint("network", 4).await.unwrap();
3651+
client_encryption
3652+
.create_data_key(gcp_master_key)
3653+
.await
3654+
.unwrap_err();
3655+
}
3656+
34963657
// FLE 2.0 Documentation Example
34973658
#[tokio::test]
34983659
async fn fle2_example() -> Result<()> {

0 commit comments

Comments
 (0)