From d258305c7b7abef0f7677505ba260480c756f1d3 Mon Sep 17 00:00:00 2001 From: Abraham Egnor Date: Tue, 1 Oct 2024 11:54:51 -0400 Subject: [PATCH 1/6] refactor for testability --- src/client/options.rs | 56 ++++++--------- src/srv.rs | 160 +++++++++++++++++++++--------------------- 2 files changed, 100 insertions(+), 116 deletions(-) diff --git a/src/client/options.rs b/src/client/options.rs index 340dcd988..c4c7c7b31 100644 --- a/src/client/options.rs +++ b/src/client/options.rs @@ -265,15 +265,6 @@ impl ServerAddress { Self::Unix { path } => path.to_string_lossy(), } } - - #[cfg(feature = "dns-resolver")] - pub(crate) fn port(&self) -> Option { - match self { - Self::Tcp { port, .. } => *port, - #[cfg(unix)] - Self::Unix { .. } => None, - } - } } impl fmt::Display for ServerAddress { @@ -1455,39 +1446,34 @@ impl ConnectionString { host_list.push(address); } - let hosts = if srv { - if host_list.len() != 1 { - return Err(ErrorKind::InvalidArgument { - message: "exactly one host must be specified with 'mongodb+srv'".into(), + let host_info = if !srv { + HostInfo::HostIdentifiers(host_list) + } else { + match &host_list[..] { + [ServerAddress::Tcp { host, port: None }] => HostInfo::DnsRecord(host.clone()), + [ServerAddress::Tcp { + host: _, + port: Some(_), + }] => { + return Err(Error::invalid_argument( + "a port cannot be specified with 'mongodb+srv'", + )); } - .into()); - } - - // Unwrap safety: the `len` check above guarantees this can't fail. - match host_list.into_iter().next().unwrap() { - ServerAddress::Tcp { host, port } => { - if port.is_some() { - return Err(ErrorKind::InvalidArgument { - message: "a port cannot be specified with 'mongodb+srv'".into(), - } - .into()); - } - HostInfo::DnsRecord(host) + [ServerAddress::Unix { .. }] => { + return Err(Error::invalid_argument( + "unix sockets cannot be used with 'mongodb+srv'", + )); } - #[cfg(unix)] - ServerAddress::Unix { .. } => { - return Err(ErrorKind::InvalidArgument { - message: "unix sockets cannot be used with 'mongodb+srv'".into(), - } - .into()); + _ => { + return Err(Error::invalid_argument( + "exactly one host must be specified with 'mongodb+srv'", + )) } } - } else { - HostInfo::HostIdentifiers(host_list) }; let mut conn_str = ConnectionString { - host_info: hosts, + host_info, #[cfg(test)] original_uri: s.into(), ..Default::default() diff --git a/src/srv.rs b/src/srv.rs index fed8ba0de..ffb1a003e 100644 --- a/src/srv.rs +++ b/src/srv.rs @@ -19,6 +19,64 @@ pub(crate) struct LookupHosts { pub(crate) min_ttl: Duration, } +impl LookupHosts { + fn validate(&mut self, original_hostname: &str, dm: DomainMismatch) -> Result<()> { + let original_hostname_parts: Vec<_> = original_hostname.split('.').collect(); + let original_domain_name = if original_hostname_parts.len() >= 3 { + &original_hostname_parts[1..] + } else { + &original_hostname_parts[..] + }; + + let mut ok_hosts = vec![]; + // TODO: + // * validate conditional n+1 hostname segment length + for addr in self.hosts.drain(..) { + let host = addr.host(); + let hostname_parts: Vec<_> = host.split('.').collect(); + if hostname_parts.ends_with(original_domain_name) { + ok_hosts.push(addr); + } else { + let message = format!( + "SRV lookup for {} returned result {}, which does not match domain name {}", + original_hostname, + host, + original_domain_name.join(".") + ); + match dm { + DomainMismatch::Error => return Err(ErrorKind::DnsResolve { message }.into()), + DomainMismatch::Skip => { + #[cfg(feature = "tracing-unstable")] + { + use crate::trace::SERVER_SELECTION_TRACING_EVENT_TARGET; + if crate::trace::trace_or_log_enabled!( + target: SERVER_SELECTION_TRACING_EVENT_TARGET, + crate::trace::TracingOrLogLevel::Warn + ) { + tracing::warn!( + target: SERVER_SELECTION_TRACING_EVENT_TARGET, + message, + ); + } + } + continue; + } + } + } + } + self.hosts = ok_hosts; + + if self.hosts.is_empty() { + return Err(ErrorKind::DnsResolve { + message: format!("SRV lookup for {} returned no records", original_hostname), + } + .into()); + } + + Ok(()) + } +} + #[derive(Debug, Clone, PartialEq)] pub(crate) struct OriginalSrvInfo { pub(crate) hostname: String, @@ -62,103 +120,43 @@ impl SrvResolver { Ok(config) } - pub(crate) async fn get_srv_hosts( - &self, - original_hostname: &str, - dm: DomainMismatch, - ) -> Result { + async fn get_srv_hosts_unvalidated(&self, lookup_hostname: &str) -> Result { use hickory_proto::rr::RData; - let hostname_parts: Vec<_> = original_hostname.split('.').collect(); - - if hostname_parts.len() < 3 { - return Err(ErrorKind::InvalidArgument { - message: "a 'mongodb+srv' hostname must have at least three '.'-delimited parts" - .into(), - } - .into()); - } - - let lookup_hostname = format!("_mongodb._tcp.{}", original_hostname); - - let srv_lookup = self.resolver.srv_lookup(lookup_hostname.as_str()).await?; - let mut srv_addresses: Vec = Vec::new(); + let srv_lookup = self.resolver.srv_lookup(lookup_hostname).await?; + let mut hosts = vec![]; let mut min_ttl = u32::MAX; - for record in srv_lookup.as_lookup().record_iter() { let srv = match record.data() { Some(RData::SRV(s)) => s, _ => continue, }; - - let hostname = srv.target().to_utf8(); - let port = Some(srv.port()); - let mut address = ServerAddress::Tcp { - host: hostname, - port, - }; - - let domain_name = &hostname_parts[1..]; - - let host = address.host(); - let mut hostname_parts: Vec<_> = host.split('.').collect(); - - // Remove empty final section, which indicates a trailing dot. - if hostname_parts.last().map(|s| s.is_empty()).unwrap_or(false) { - hostname_parts.pop(); - } - - if !&hostname_parts[1..].ends_with(domain_name) { - let message = format!( - "SRV lookup for {} returned result {}, which does not match domain name {}", - original_hostname, - address, - domain_name.join(".") - ); - if matches!(dm, DomainMismatch::Error) { - return Err(ErrorKind::DnsResolve { message }.into()); - } else { - #[cfg(feature = "tracing-unstable")] - { - use crate::trace::SERVER_SELECTION_TRACING_EVENT_TARGET; - if crate::trace::trace_or_log_enabled!( - target: SERVER_SELECTION_TRACING_EVENT_TARGET, - crate::trace::TracingOrLogLevel::Warn - ) { - tracing::warn!( - target: SERVER_SELECTION_TRACING_EVENT_TARGET, - message, - ); - } - } - } - continue; + let mut host = srv.target().to_utf8(); + // Remove the trailing '.' + if host.ends_with('.') { + host.pop(); } - - // The spec tests list the seeds without the trailing '.', so we remove it by - // joining the parts we split rather than manipulating the string. - address = ServerAddress::Tcp { - host: hostname_parts.join("."), - port: address.port(), - }; - + let port = Some(srv.port()); + hosts.push(ServerAddress::Tcp { host, port }); min_ttl = std::cmp::min(min_ttl, record.ttl()); - srv_addresses.push(address); - } - - if srv_addresses.is_empty() { - return Err(ErrorKind::DnsResolve { - message: format!("SRV lookup for {} returned no records", original_hostname), - } - .into()); } - Ok(LookupHosts { - hosts: srv_addresses, + hosts, min_ttl: Duration::from_secs(min_ttl.into()), }) } + pub(crate) async fn get_srv_hosts( + &self, + original_hostname: &str, + dm: DomainMismatch, + ) -> Result { + let lookup_hostname = format!("_mongodb._tcp.{}", original_hostname); + let mut lookup_hosts = self.get_srv_hosts_unvalidated(&lookup_hostname).await?; + lookup_hosts.validate(original_hostname, dm)?; + Ok(lookup_hosts) + } + async fn get_txt_options( &self, original_hostname: &str, From c1617f4878a1326ebbf4f13c96f5de177aa7e8b4 Mon Sep 17 00:00:00 2001 From: Abraham Egnor Date: Tue, 1 Oct 2024 11:58:42 -0400 Subject: [PATCH 2/6] length check --- src/srv.rs | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) diff --git a/src/srv.rs b/src/srv.rs index ffb1a003e..8669d3558 100644 --- a/src/srv.rs +++ b/src/srv.rs @@ -22,19 +22,22 @@ pub(crate) struct LookupHosts { impl LookupHosts { fn validate(&mut self, original_hostname: &str, dm: DomainMismatch) -> Result<()> { let original_hostname_parts: Vec<_> = original_hostname.split('.').collect(); - let original_domain_name = if original_hostname_parts.len() >= 3 { - &original_hostname_parts[1..] + let (original_domain_name, min_len) = if original_hostname_parts.len() >= 3 { + (&original_hostname_parts[1..], 3) } else { - &original_hostname_parts[..] + ( + &original_hostname_parts[..], + original_hostname_parts.len() + 1, + ) }; let mut ok_hosts = vec![]; - // TODO: - // * validate conditional n+1 hostname segment length for addr in self.hosts.drain(..) { let host = addr.host(); let hostname_parts: Vec<_> = host.split('.').collect(); - if hostname_parts.ends_with(original_domain_name) { + let hostname_matches = hostname_parts.ends_with(original_domain_name); + let length_ok = hostname_parts.len() >= min_len; + if hostname_matches && length_ok { ok_hosts.push(addr); } else { let message = format!( From 8972a916a320a29e53efc3640adb9a62999c6f92 Mon Sep 17 00:00:00 2001 From: Abraham Egnor Date: Tue, 1 Oct 2024 14:29:50 -0400 Subject: [PATCH 3/6] tests --- src/srv.rs | 2 +- .../spec/initial_dns_seedlist_discovery.rs | 44 ++++++++++++++++++- 2 files changed, 44 insertions(+), 2 deletions(-) diff --git a/src/srv.rs b/src/srv.rs index 8669d3558..f3dcf4762 100644 --- a/src/srv.rs +++ b/src/srv.rs @@ -20,7 +20,7 @@ pub(crate) struct LookupHosts { } impl LookupHosts { - fn validate(&mut self, original_hostname: &str, dm: DomainMismatch) -> Result<()> { + pub(crate) fn validate(&mut self, original_hostname: &str, dm: DomainMismatch) -> Result<()> { let original_hostname_parts: Vec<_> = original_hostname.split('.').collect(); let (original_domain_name, min_len) = if original_hostname_parts.len() >= 3 { (&original_hostname_parts[1..], 3) diff --git a/src/test/spec/initial_dns_seedlist_discovery.rs b/src/test/spec/initial_dns_seedlist_discovery.rs index ec2822ec4..bbf6d92c8 100644 --- a/src/test/spec/initial_dns_seedlist_discovery.rs +++ b/src/test/spec/initial_dns_seedlist_discovery.rs @@ -5,7 +5,8 @@ use serde::Deserialize; use crate::{ bson::doc, client::Client, - options::{ClientOptions, ResolverConfig}, + options::{ClientOptions, ResolverConfig, ServerAddress}, + srv::{DomainMismatch, LookupHosts}, test::{get_client_options, log_uncaptured, run_spec_test}, }; @@ -255,3 +256,44 @@ async fn sharded() { } run_spec_test(&["initial-dns-seedlist-discovery", "sharded"], run_test).await; } + +fn validate_srv(original: &str, resolved: &str) -> crate::error::Result<()> { + let mut lh = LookupHosts { + hosts: vec![ServerAddress::Tcp { + host: resolved.to_string(), + port: Some(42), + }], + min_ttl: Duration::from_secs(1), + }; + lh.validate(original, DomainMismatch::Error) +} + +// Prose test 1. Allow SRVs with fewer than 3 `.` separated parts +#[test] +fn short_srv_domains_valid() { + validate_srv("localhost", "test.localhost").unwrap(); + validate_srv("mongo.local", "test.mongo.local").unwrap(); +} + +// Prose test 2. Throw when return address does not end with SRV domain +#[test] +fn short_srv_domains_invalid_end() { + assert!(validate_srv("localhost", "localhost.mongodb").is_err()); + assert!(validate_srv("mongo.local", "test_1.evil.local").is_err()); + assert!(validate_srv("blogs.mongodb.com", "blogs.evil.com").is_err()); +} + +// Prose test 3. Throw when return address is identical to SRV hostname +#[test] +fn short_srv_domains_invalid_identical() { + assert!(validate_srv("localhost", "localhost").is_err()); + assert!(validate_srv("mongo.local", "mongo.local").is_err()); +} + +// Prose test 4. Throw when return address does not contain `.` separating shared part of domain +#[test] +fn short_srv_domains_invalid_no_dot() { + assert!(validate_srv("localhost", "test_1.cluster_1localhost").is_err()); + assert!(validate_srv("mongo.local", "test_1.my_hostmongo.local").is_err()); + assert!(validate_srv("blogs.mongodb.com", "cluster.testmongodb.com").is_err()); +} From 22edc89a3a25d6a967aa5a3d1d9bee01aceb4787 Mon Sep 17 00:00:00 2001 From: Abraham Egnor Date: Tue, 1 Oct 2024 14:38:36 -0400 Subject: [PATCH 4/6] add back in unix --- src/client/options.rs | 1 + 1 file changed, 1 insertion(+) diff --git a/src/client/options.rs b/src/client/options.rs index c4c7c7b31..0740e8e77 100644 --- a/src/client/options.rs +++ b/src/client/options.rs @@ -1459,6 +1459,7 @@ impl ConnectionString { "a port cannot be specified with 'mongodb+srv'", )); } + #[cfg(unix)] [ServerAddress::Unix { .. }] => { return Err(Error::invalid_argument( "unix sockets cannot be used with 'mongodb+srv'", From 0d38e7b27d747650265ad8eaf135dfeaee239534 Mon Sep 17 00:00:00 2001 From: Abraham Egnor Date: Wed, 2 Oct 2024 11:10:18 -0400 Subject: [PATCH 5/6] forgot to carry the one --- src/srv.rs | 13 ++++--------- 1 file changed, 4 insertions(+), 9 deletions(-) diff --git a/src/srv.rs b/src/srv.rs index f3dcf4762..538e17379 100644 --- a/src/srv.rs +++ b/src/srv.rs @@ -22,22 +22,17 @@ pub(crate) struct LookupHosts { impl LookupHosts { pub(crate) fn validate(&mut self, original_hostname: &str, dm: DomainMismatch) -> Result<()> { let original_hostname_parts: Vec<_> = original_hostname.split('.').collect(); - let (original_domain_name, min_len) = if original_hostname_parts.len() >= 3 { - (&original_hostname_parts[1..], 3) + let original_domain_name = if original_hostname_parts.len() >= 3 { + &original_hostname_parts[1..] } else { - ( - &original_hostname_parts[..], - original_hostname_parts.len() + 1, - ) + &original_hostname_parts[..] }; let mut ok_hosts = vec![]; for addr in self.hosts.drain(..) { let host = addr.host(); let hostname_parts: Vec<_> = host.split('.').collect(); - let hostname_matches = hostname_parts.ends_with(original_domain_name); - let length_ok = hostname_parts.len() >= min_len; - if hostname_matches && length_ok { + if hostname_parts[1..].ends_with(original_domain_name) { ok_hosts.push(addr); } else { let message = format!( From 7d8462e0b02433ccdd0598cf8d27f0cabb8cbd59 Mon Sep 17 00:00:00 2001 From: Abraham Egnor Date: Wed, 2 Oct 2024 11:59:25 -0400 Subject: [PATCH 6/6] chain --- src/srv.rs | 10 +++++----- src/test/spec/initial_dns_seedlist_discovery.rs | 7 ++++--- 2 files changed, 9 insertions(+), 8 deletions(-) diff --git a/src/srv.rs b/src/srv.rs index 538e17379..d45863495 100644 --- a/src/srv.rs +++ b/src/srv.rs @@ -20,7 +20,7 @@ pub(crate) struct LookupHosts { } impl LookupHosts { - pub(crate) fn validate(&mut self, original_hostname: &str, dm: DomainMismatch) -> Result<()> { + pub(crate) fn validate(mut self, original_hostname: &str, dm: DomainMismatch) -> Result { let original_hostname_parts: Vec<_> = original_hostname.split('.').collect(); let original_domain_name = if original_hostname_parts.len() >= 3 { &original_hostname_parts[1..] @@ -71,7 +71,7 @@ impl LookupHosts { .into()); } - Ok(()) + Ok(self) } } @@ -150,9 +150,9 @@ impl SrvResolver { dm: DomainMismatch, ) -> Result { let lookup_hostname = format!("_mongodb._tcp.{}", original_hostname); - let mut lookup_hosts = self.get_srv_hosts_unvalidated(&lookup_hostname).await?; - lookup_hosts.validate(original_hostname, dm)?; - Ok(lookup_hosts) + self.get_srv_hosts_unvalidated(&lookup_hostname) + .await? + .validate(original_hostname, dm) } async fn get_txt_options( diff --git a/src/test/spec/initial_dns_seedlist_discovery.rs b/src/test/spec/initial_dns_seedlist_discovery.rs index bbf6d92c8..b04c1219b 100644 --- a/src/test/spec/initial_dns_seedlist_discovery.rs +++ b/src/test/spec/initial_dns_seedlist_discovery.rs @@ -258,14 +258,15 @@ async fn sharded() { } fn validate_srv(original: &str, resolved: &str) -> crate::error::Result<()> { - let mut lh = LookupHosts { + LookupHosts { hosts: vec![ServerAddress::Tcp { host: resolved.to_string(), port: Some(42), }], min_ttl: Duration::from_secs(1), - }; - lh.validate(original, DomainMismatch::Error) + } + .validate(original, DomainMismatch::Error) + .map(|_| ()) } // Prose test 1. Allow SRVs with fewer than 3 `.` separated parts