diff --git a/src/client/options.rs b/src/client/options.rs index 340dcd988..0740e8e77 100644 --- a/src/client/options.rs +++ b/src/client/options.rs @@ -265,15 +265,6 @@ impl ServerAddress { Self::Unix { path } => path.to_string_lossy(), } } - - #[cfg(feature = "dns-resolver")] - pub(crate) fn port(&self) -> Option { - match self { - Self::Tcp { port, .. } => *port, - #[cfg(unix)] - Self::Unix { .. } => None, - } - } } impl fmt::Display for ServerAddress { @@ -1455,39 +1446,35 @@ impl ConnectionString { host_list.push(address); } - let hosts = if srv { - if host_list.len() != 1 { - return Err(ErrorKind::InvalidArgument { - message: "exactly one host must be specified with 'mongodb+srv'".into(), - } - .into()); - } - - // Unwrap safety: the `len` check above guarantees this can't fail. - match host_list.into_iter().next().unwrap() { - ServerAddress::Tcp { host, port } => { - if port.is_some() { - return Err(ErrorKind::InvalidArgument { - message: "a port cannot be specified with 'mongodb+srv'".into(), - } - .into()); - } - HostInfo::DnsRecord(host) + let host_info = if !srv { + HostInfo::HostIdentifiers(host_list) + } else { + match &host_list[..] { + [ServerAddress::Tcp { host, port: None }] => HostInfo::DnsRecord(host.clone()), + [ServerAddress::Tcp { + host: _, + port: Some(_), + }] => { + return Err(Error::invalid_argument( + "a port cannot be specified with 'mongodb+srv'", + )); } #[cfg(unix)] - ServerAddress::Unix { .. } => { - return Err(ErrorKind::InvalidArgument { - message: "unix sockets cannot be used with 'mongodb+srv'".into(), - } - .into()); + [ServerAddress::Unix { .. }] => { + return Err(Error::invalid_argument( + "unix sockets cannot be used with 'mongodb+srv'", + )); + } + _ => { + return Err(Error::invalid_argument( + "exactly one host must be specified with 'mongodb+srv'", + )) } } - } else { - HostInfo::HostIdentifiers(host_list) }; let mut conn_str = ConnectionString { - host_info: hosts, + host_info, #[cfg(test)] original_uri: s.into(), ..Default::default() diff --git a/src/srv.rs b/src/srv.rs index fed8ba0de..d45863495 100644 --- a/src/srv.rs +++ b/src/srv.rs @@ -19,6 +19,62 @@ pub(crate) struct LookupHosts { pub(crate) min_ttl: Duration, } +impl LookupHosts { + pub(crate) fn validate(mut self, original_hostname: &str, dm: DomainMismatch) -> Result { + let original_hostname_parts: Vec<_> = original_hostname.split('.').collect(); + let original_domain_name = if original_hostname_parts.len() >= 3 { + &original_hostname_parts[1..] + } else { + &original_hostname_parts[..] + }; + + let mut ok_hosts = vec![]; + for addr in self.hosts.drain(..) { + let host = addr.host(); + let hostname_parts: Vec<_> = host.split('.').collect(); + if hostname_parts[1..].ends_with(original_domain_name) { + ok_hosts.push(addr); + } else { + let message = format!( + "SRV lookup for {} returned result {}, which does not match domain name {}", + original_hostname, + host, + original_domain_name.join(".") + ); + match dm { + DomainMismatch::Error => return Err(ErrorKind::DnsResolve { message }.into()), + DomainMismatch::Skip => { + #[cfg(feature = "tracing-unstable")] + { + use crate::trace::SERVER_SELECTION_TRACING_EVENT_TARGET; + if crate::trace::trace_or_log_enabled!( + target: SERVER_SELECTION_TRACING_EVENT_TARGET, + crate::trace::TracingOrLogLevel::Warn + ) { + tracing::warn!( + target: SERVER_SELECTION_TRACING_EVENT_TARGET, + message, + ); + } + } + continue; + } + } + } + } + self.hosts = ok_hosts; + + if self.hosts.is_empty() { + return Err(ErrorKind::DnsResolve { + message: format!("SRV lookup for {} returned no records", original_hostname), + } + .into()); + } + + Ok(self) + } +} + #[derive(Debug, Clone, PartialEq)] pub(crate) struct OriginalSrvInfo { pub(crate) hostname: String, @@ -62,103 +118,43 @@ impl SrvResolver { Ok(config) } - pub(crate) async fn get_srv_hosts( - &self, - original_hostname: &str, - dm: DomainMismatch, - ) -> Result { + async fn get_srv_hosts_unvalidated(&self, lookup_hostname: &str) -> Result { use hickory_proto::rr::RData; - let hostname_parts: Vec<_> = original_hostname.split('.').collect(); - - if hostname_parts.len() < 3 { - return Err(ErrorKind::InvalidArgument { - message: "a 'mongodb+srv' hostname must have at least three '.'-delimited parts" - .into(), - } - .into()); - } - - let lookup_hostname = format!("_mongodb._tcp.{}", original_hostname); - - let srv_lookup = self.resolver.srv_lookup(lookup_hostname.as_str()).await?; - let mut srv_addresses: Vec = Vec::new(); + let srv_lookup = self.resolver.srv_lookup(lookup_hostname).await?; + let mut hosts = vec![]; let mut min_ttl = u32::MAX; - for record in srv_lookup.as_lookup().record_iter() { let srv = match record.data() { Some(RData::SRV(s)) => s, _ => continue, }; - - let hostname = srv.target().to_utf8(); - let port = Some(srv.port()); - let mut address = ServerAddress::Tcp { - host: hostname, - port, - }; - - let domain_name = &hostname_parts[1..]; - - let host = address.host(); - let mut hostname_parts: Vec<_> = host.split('.').collect(); - - // Remove empty final section, which indicates a trailing dot. - if hostname_parts.last().map(|s| s.is_empty()).unwrap_or(false) { - hostname_parts.pop(); - } - - if !&hostname_parts[1..].ends_with(domain_name) { - let message = format!( - "SRV lookup for {} returned result {}, which does not match domain name {}", - original_hostname, - address, - domain_name.join(".") - ); - if matches!(dm, DomainMismatch::Error) { - return Err(ErrorKind::DnsResolve { message }.into()); - } else { - #[cfg(feature = "tracing-unstable")] - { - use crate::trace::SERVER_SELECTION_TRACING_EVENT_TARGET; - if crate::trace::trace_or_log_enabled!( - target: SERVER_SELECTION_TRACING_EVENT_TARGET, - crate::trace::TracingOrLogLevel::Warn - ) { - tracing::warn!( - target: SERVER_SELECTION_TRACING_EVENT_TARGET, - message, - ); - } - } - } - continue; + let mut host = srv.target().to_utf8(); + // Remove the trailing '.' + if host.ends_with('.') { + host.pop(); } - - // The spec tests list the seeds without the trailing '.', so we remove it by - // joining the parts we split rather than manipulating the string. - address = ServerAddress::Tcp { - host: hostname_parts.join("."), - port: address.port(), - }; - + let port = Some(srv.port()); + hosts.push(ServerAddress::Tcp { host, port }); min_ttl = std::cmp::min(min_ttl, record.ttl()); - srv_addresses.push(address); } - - if srv_addresses.is_empty() { - return Err(ErrorKind::DnsResolve { - message: format!("SRV lookup for {} returned no records", original_hostname), - } - .into()); - } - Ok(LookupHosts { - hosts: srv_addresses, + hosts, min_ttl: Duration::from_secs(min_ttl.into()), }) } + pub(crate) async fn get_srv_hosts( + &self, + original_hostname: &str, + dm: DomainMismatch, + ) -> Result { + let lookup_hostname = format!("_mongodb._tcp.{}", original_hostname); + self.get_srv_hosts_unvalidated(&lookup_hostname) + .await? + .validate(original_hostname, dm) + } + async fn get_txt_options( &self, original_hostname: &str, diff --git a/src/test/spec/initial_dns_seedlist_discovery.rs b/src/test/spec/initial_dns_seedlist_discovery.rs index ec2822ec4..b04c1219b 100644 --- a/src/test/spec/initial_dns_seedlist_discovery.rs +++ b/src/test/spec/initial_dns_seedlist_discovery.rs @@ -5,7 +5,8 @@ use serde::Deserialize; use crate::{ bson::doc, client::Client, - options::{ClientOptions, ResolverConfig}, + options::{ClientOptions, ResolverConfig, ServerAddress}, + srv::{DomainMismatch, LookupHosts}, test::{get_client_options, log_uncaptured, run_spec_test}, }; @@ -255,3 +256,45 @@ async fn sharded() { } run_spec_test(&["initial-dns-seedlist-discovery", "sharded"], run_test).await; } + +fn validate_srv(original: &str, resolved: &str) -> crate::error::Result<()> { + LookupHosts { + hosts: vec![ServerAddress::Tcp { + host: resolved.to_string(), + port: Some(42), + }], + min_ttl: Duration::from_secs(1), + } + .validate(original, DomainMismatch::Error) + .map(|_| ()) +} + +// Prose test 1. Allow SRVs with fewer than 3 `.` separated parts +#[test] +fn short_srv_domains_valid() { + validate_srv("localhost", "test.localhost").unwrap(); + validate_srv("mongo.local", "test.mongo.local").unwrap(); +} + +// Prose test 2. Throw when return address does not end with SRV domain +#[test] +fn short_srv_domains_invalid_end() { + assert!(validate_srv("localhost", "localhost.mongodb").is_err()); + assert!(validate_srv("mongo.local", "test_1.evil.local").is_err()); + assert!(validate_srv("blogs.mongodb.com", "blogs.evil.com").is_err()); +} + +// Prose test 3. Throw when return address is identical to SRV hostname +#[test] +fn short_srv_domains_invalid_identical() { + assert!(validate_srv("localhost", "localhost").is_err()); + assert!(validate_srv("mongo.local", "mongo.local").is_err()); +} + +// Prose test 4. Throw when return address does not contain `.` separating shared part of domain +#[test] +fn short_srv_domains_invalid_no_dot() { + assert!(validate_srv("localhost", "test_1.cluster_1localhost").is_err()); + assert!(validate_srv("mongo.local", "test_1.my_hostmongo.local").is_err()); + assert!(validate_srv("blogs.mongodb.com", "cluster.testmongodb.com").is_err()); +}