Skip to content

RUST-2001 Allow SRV hostnames with less than three parts #1211

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
Oct 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 22 additions & 35 deletions src/client/options.rs
Original file line number Diff line number Diff line change
Expand Up @@ -265,15 +265,6 @@ impl ServerAddress {
Self::Unix { path } => path.to_string_lossy(),
}
}

#[cfg(feature = "dns-resolver")]
pub(crate) fn port(&self) -> Option<u16> {
match self {
Self::Tcp { port, .. } => *port,
#[cfg(unix)]
Self::Unix { .. } => None,
}
}
}

impl fmt::Display for ServerAddress {
Expand Down Expand Up @@ -1455,39 +1446,35 @@ impl ConnectionString {
host_list.push(address);
}

let hosts = if srv {
if host_list.len() != 1 {
return Err(ErrorKind::InvalidArgument {
message: "exactly one host must be specified with 'mongodb+srv'".into(),
}
.into());
}

// Unwrap safety: the `len` check above guarantees this can't fail.
match host_list.into_iter().next().unwrap() {
ServerAddress::Tcp { host, port } => {
if port.is_some() {
return Err(ErrorKind::InvalidArgument {
message: "a port cannot be specified with 'mongodb+srv'".into(),
}
.into());
}
HostInfo::DnsRecord(host)
let host_info = if !srv {
Copy link
Contributor Author

@abr-egn abr-egn Oct 1, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This isn't directly related to the change, but when I was reading through to understand the original code I had a hard time skimming this section. This changes it to put the success cases up front and have the error cases nicely enumerated.

HostInfo::HostIdentifiers(host_list)
} else {
match &host_list[..] {
[ServerAddress::Tcp { host, port: None }] => HostInfo::DnsRecord(host.clone()),
[ServerAddress::Tcp {
host: _,
port: Some(_),
}] => {
return Err(Error::invalid_argument(
"a port cannot be specified with 'mongodb+srv'",
));
}
#[cfg(unix)]
ServerAddress::Unix { .. } => {
return Err(ErrorKind::InvalidArgument {
message: "unix sockets cannot be used with 'mongodb+srv'".into(),
}
.into());
[ServerAddress::Unix { .. }] => {
return Err(Error::invalid_argument(
"unix sockets cannot be used with 'mongodb+srv'",
));
}
_ => {
return Err(Error::invalid_argument(
"exactly one host must be specified with 'mongodb+srv'",
))
}
}
} else {
HostInfo::HostIdentifiers(host_list)
};

let mut conn_str = ConnectionString {
host_info: hosts,
host_info,
#[cfg(test)]
original_uri: s.into(),
..Default::default()
Expand Down
158 changes: 77 additions & 81 deletions src/srv.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,62 @@ pub(crate) struct LookupHosts {
pub(crate) min_ttl: Duration,
}

impl LookupHosts {
pub(crate) fn validate(mut self, original_hostname: &str, dm: DomainMismatch) -> Result<Self> {
let original_hostname_parts: Vec<_> = original_hostname.split('.').collect();
let original_domain_name = if original_hostname_parts.len() >= 3 {
&original_hostname_parts[1..]
} else {
&original_hostname_parts[..]
};

let mut ok_hosts = vec![];
for addr in self.hosts.drain(..) {
let host = addr.host();
let hostname_parts: Vec<_> = host.split('.').collect();
if hostname_parts[1..].ends_with(original_domain_name) {
ok_hosts.push(addr);
} else {
let message = format!(
"SRV lookup for {} returned result {}, which does not match domain name {}",
original_hostname,
host,
original_domain_name.join(".")
);
match dm {
DomainMismatch::Error => return Err(ErrorKind::DnsResolve { message }.into()),
DomainMismatch::Skip => {
#[cfg(feature = "tracing-unstable")]
{
use crate::trace::SERVER_SELECTION_TRACING_EVENT_TARGET;
if crate::trace::trace_or_log_enabled!(
target: SERVER_SELECTION_TRACING_EVENT_TARGET,
crate::trace::TracingOrLogLevel::Warn
) {
tracing::warn!(
target: SERVER_SELECTION_TRACING_EVENT_TARGET,
message,
);
}
}
continue;
}
}
}
}
self.hosts = ok_hosts;

if self.hosts.is_empty() {
return Err(ErrorKind::DnsResolve {
message: format!("SRV lookup for {} returned no records", original_hostname),
}
.into());
}

Ok(self)
}
}

#[derive(Debug, Clone, PartialEq)]
pub(crate) struct OriginalSrvInfo {
pub(crate) hostname: String,
Expand Down Expand Up @@ -62,103 +118,43 @@ impl SrvResolver {
Ok(config)
}

pub(crate) async fn get_srv_hosts(
&self,
original_hostname: &str,
dm: DomainMismatch,
) -> Result<LookupHosts> {
async fn get_srv_hosts_unvalidated(&self, lookup_hostname: &str) -> Result<LookupHosts> {
use hickory_proto::rr::RData;

let hostname_parts: Vec<_> = original_hostname.split('.').collect();

if hostname_parts.len() < 3 {
return Err(ErrorKind::InvalidArgument {
message: "a 'mongodb+srv' hostname must have at least three '.'-delimited parts"
.into(),
}
.into());
}

let lookup_hostname = format!("_mongodb._tcp.{}", original_hostname);

let srv_lookup = self.resolver.srv_lookup(lookup_hostname.as_str()).await?;
let mut srv_addresses: Vec<ServerAddress> = Vec::new();
let srv_lookup = self.resolver.srv_lookup(lookup_hostname).await?;
let mut hosts = vec![];
let mut min_ttl = u32::MAX;

for record in srv_lookup.as_lookup().record_iter() {
let srv = match record.data() {
Some(RData::SRV(s)) => s,
_ => continue,
};

let hostname = srv.target().to_utf8();
let port = Some(srv.port());
let mut address = ServerAddress::Tcp {
host: hostname,
port,
};

let domain_name = &hostname_parts[1..];

let host = address.host();
let mut hostname_parts: Vec<_> = host.split('.').collect();

// Remove empty final section, which indicates a trailing dot.
if hostname_parts.last().map(|s| s.is_empty()).unwrap_or(false) {
hostname_parts.pop();
}

if !&hostname_parts[1..].ends_with(domain_name) {
let message = format!(
"SRV lookup for {} returned result {}, which does not match domain name {}",
original_hostname,
address,
domain_name.join(".")
);
if matches!(dm, DomainMismatch::Error) {
return Err(ErrorKind::DnsResolve { message }.into());
} else {
#[cfg(feature = "tracing-unstable")]
{
use crate::trace::SERVER_SELECTION_TRACING_EVENT_TARGET;
if crate::trace::trace_or_log_enabled!(
target: SERVER_SELECTION_TRACING_EVENT_TARGET,
crate::trace::TracingOrLogLevel::Warn
) {
tracing::warn!(
target: SERVER_SELECTION_TRACING_EVENT_TARGET,
message,
);
}
}
}
continue;
let mut host = srv.target().to_utf8();
// Remove the trailing '.'
if host.ends_with('.') {
host.pop();
}

// The spec tests list the seeds without the trailing '.', so we remove it by
// joining the parts we split rather than manipulating the string.
address = ServerAddress::Tcp {
host: hostname_parts.join("."),
port: address.port(),
};

let port = Some(srv.port());
hosts.push(ServerAddress::Tcp { host, port });
min_ttl = std::cmp::min(min_ttl, record.ttl());
srv_addresses.push(address);
}

if srv_addresses.is_empty() {
return Err(ErrorKind::DnsResolve {
message: format!("SRV lookup for {} returned no records", original_hostname),
}
.into());
}

Ok(LookupHosts {
hosts: srv_addresses,
hosts,
min_ttl: Duration::from_secs(min_ttl.into()),
})
}

pub(crate) async fn get_srv_hosts(
&self,
original_hostname: &str,
dm: DomainMismatch,
) -> Result<LookupHosts> {
let lookup_hostname = format!("_mongodb._tcp.{}", original_hostname);
self.get_srv_hosts_unvalidated(&lookup_hostname)
.await?
.validate(original_hostname, dm)
}

async fn get_txt_options(
&self,
original_hostname: &str,
Expand Down
45 changes: 44 additions & 1 deletion src/test/spec/initial_dns_seedlist_discovery.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,8 @@ use serde::Deserialize;
use crate::{
bson::doc,
client::Client,
options::{ClientOptions, ResolverConfig},
options::{ClientOptions, ResolverConfig, ServerAddress},
srv::{DomainMismatch, LookupHosts},
test::{get_client_options, log_uncaptured, run_spec_test},
};

Expand Down Expand Up @@ -255,3 +256,45 @@ async fn sharded() {
}
run_spec_test(&["initial-dns-seedlist-discovery", "sharded"], run_test).await;
}

fn validate_srv(original: &str, resolved: &str) -> crate::error::Result<()> {
LookupHosts {
hosts: vec![ServerAddress::Tcp {
host: resolved.to_string(),
port: Some(42),
}],
min_ttl: Duration::from_secs(1),
}
.validate(original, DomainMismatch::Error)
.map(|_| ())
}

// Prose test 1. Allow SRVs with fewer than 3 `.` separated parts
#[test]
fn short_srv_domains_valid() {
validate_srv("localhost", "test.localhost").unwrap();
validate_srv("mongo.local", "test.mongo.local").unwrap();
}

// Prose test 2. Throw when return address does not end with SRV domain
#[test]
fn short_srv_domains_invalid_end() {
assert!(validate_srv("localhost", "localhost.mongodb").is_err());
assert!(validate_srv("mongo.local", "test_1.evil.local").is_err());
assert!(validate_srv("blogs.mongodb.com", "blogs.evil.com").is_err());
}

// Prose test 3. Throw when return address is identical to SRV hostname
#[test]
fn short_srv_domains_invalid_identical() {
assert!(validate_srv("localhost", "localhost").is_err());
assert!(validate_srv("mongo.local", "mongo.local").is_err());
}

// Prose test 4. Throw when return address does not contain `.` separating shared part of domain
#[test]
fn short_srv_domains_invalid_no_dot() {
assert!(validate_srv("localhost", "test_1.cluster_1localhost").is_err());
assert!(validate_srv("mongo.local", "test_1.my_hostmongo.local").is_err());
assert!(validate_srv("blogs.mongodb.com", "cluster.testmongodb.com").is_err());
}