Skip to content

Commit 7a408c3

Browse files
authored
RUST-2001 Allow SRV hostnames with less than three parts (#1211)
1 parent a6f76b8 commit 7a408c3

File tree

3 files changed

+143
-117
lines changed

3 files changed

+143
-117
lines changed

src/client/options.rs

Lines changed: 22 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -265,15 +265,6 @@ impl ServerAddress {
265265
Self::Unix { path } => path.to_string_lossy(),
266266
}
267267
}
268-
269-
#[cfg(feature = "dns-resolver")]
270-
pub(crate) fn port(&self) -> Option<u16> {
271-
match self {
272-
Self::Tcp { port, .. } => *port,
273-
#[cfg(unix)]
274-
Self::Unix { .. } => None,
275-
}
276-
}
277268
}
278269

279270
impl fmt::Display for ServerAddress {
@@ -1455,39 +1446,35 @@ impl ConnectionString {
14551446
host_list.push(address);
14561447
}
14571448

1458-
let hosts = if srv {
1459-
if host_list.len() != 1 {
1460-
return Err(ErrorKind::InvalidArgument {
1461-
message: "exactly one host must be specified with 'mongodb+srv'".into(),
1462-
}
1463-
.into());
1464-
}
1465-
1466-
// Unwrap safety: the `len` check above guarantees this can't fail.
1467-
match host_list.into_iter().next().unwrap() {
1468-
ServerAddress::Tcp { host, port } => {
1469-
if port.is_some() {
1470-
return Err(ErrorKind::InvalidArgument {
1471-
message: "a port cannot be specified with 'mongodb+srv'".into(),
1472-
}
1473-
.into());
1474-
}
1475-
HostInfo::DnsRecord(host)
1449+
let host_info = if !srv {
1450+
HostInfo::HostIdentifiers(host_list)
1451+
} else {
1452+
match &host_list[..] {
1453+
[ServerAddress::Tcp { host, port: None }] => HostInfo::DnsRecord(host.clone()),
1454+
[ServerAddress::Tcp {
1455+
host: _,
1456+
port: Some(_),
1457+
}] => {
1458+
return Err(Error::invalid_argument(
1459+
"a port cannot be specified with 'mongodb+srv'",
1460+
));
14761461
}
14771462
#[cfg(unix)]
1478-
ServerAddress::Unix { .. } => {
1479-
return Err(ErrorKind::InvalidArgument {
1480-
message: "unix sockets cannot be used with 'mongodb+srv'".into(),
1481-
}
1482-
.into());
1463+
[ServerAddress::Unix { .. }] => {
1464+
return Err(Error::invalid_argument(
1465+
"unix sockets cannot be used with 'mongodb+srv'",
1466+
));
1467+
}
1468+
_ => {
1469+
return Err(Error::invalid_argument(
1470+
"exactly one host must be specified with 'mongodb+srv'",
1471+
))
14831472
}
14841473
}
1485-
} else {
1486-
HostInfo::HostIdentifiers(host_list)
14871474
};
14881475

14891476
let mut conn_str = ConnectionString {
1490-
host_info: hosts,
1477+
host_info,
14911478
#[cfg(test)]
14921479
original_uri: s.into(),
14931480
..Default::default()

src/srv.rs

Lines changed: 77 additions & 81 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,62 @@ pub(crate) struct LookupHosts {
1919
pub(crate) min_ttl: Duration,
2020
}
2121

22+
impl LookupHosts {
23+
pub(crate) fn validate(mut self, original_hostname: &str, dm: DomainMismatch) -> Result<Self> {
24+
let original_hostname_parts: Vec<_> = original_hostname.split('.').collect();
25+
let original_domain_name = if original_hostname_parts.len() >= 3 {
26+
&original_hostname_parts[1..]
27+
} else {
28+
&original_hostname_parts[..]
29+
};
30+
31+
let mut ok_hosts = vec![];
32+
for addr in self.hosts.drain(..) {
33+
let host = addr.host();
34+
let hostname_parts: Vec<_> = host.split('.').collect();
35+
if hostname_parts[1..].ends_with(original_domain_name) {
36+
ok_hosts.push(addr);
37+
} else {
38+
let message = format!(
39+
"SRV lookup for {} returned result {}, which does not match domain name {}",
40+
original_hostname,
41+
host,
42+
original_domain_name.join(".")
43+
);
44+
match dm {
45+
DomainMismatch::Error => return Err(ErrorKind::DnsResolve { message }.into()),
46+
DomainMismatch::Skip => {
47+
#[cfg(feature = "tracing-unstable")]
48+
{
49+
use crate::trace::SERVER_SELECTION_TRACING_EVENT_TARGET;
50+
if crate::trace::trace_or_log_enabled!(
51+
target: SERVER_SELECTION_TRACING_EVENT_TARGET,
52+
crate::trace::TracingOrLogLevel::Warn
53+
) {
54+
tracing::warn!(
55+
target: SERVER_SELECTION_TRACING_EVENT_TARGET,
56+
message,
57+
);
58+
}
59+
}
60+
continue;
61+
}
62+
}
63+
}
64+
}
65+
self.hosts = ok_hosts;
66+
67+
if self.hosts.is_empty() {
68+
return Err(ErrorKind::DnsResolve {
69+
message: format!("SRV lookup for {} returned no records", original_hostname),
70+
}
71+
.into());
72+
}
73+
74+
Ok(self)
75+
}
76+
}
77+
2278
#[derive(Debug, Clone, PartialEq)]
2379
pub(crate) struct OriginalSrvInfo {
2480
pub(crate) hostname: String,
@@ -62,103 +118,43 @@ impl SrvResolver {
62118
Ok(config)
63119
}
64120

65-
pub(crate) async fn get_srv_hosts(
66-
&self,
67-
original_hostname: &str,
68-
dm: DomainMismatch,
69-
) -> Result<LookupHosts> {
121+
async fn get_srv_hosts_unvalidated(&self, lookup_hostname: &str) -> Result<LookupHosts> {
70122
use hickory_proto::rr::RData;
71123

72-
let hostname_parts: Vec<_> = original_hostname.split('.').collect();
73-
74-
if hostname_parts.len() < 3 {
75-
return Err(ErrorKind::InvalidArgument {
76-
message: "a 'mongodb+srv' hostname must have at least three '.'-delimited parts"
77-
.into(),
78-
}
79-
.into());
80-
}
81-
82-
let lookup_hostname = format!("_mongodb._tcp.{}", original_hostname);
83-
84-
let srv_lookup = self.resolver.srv_lookup(lookup_hostname.as_str()).await?;
85-
let mut srv_addresses: Vec<ServerAddress> = Vec::new();
124+
let srv_lookup = self.resolver.srv_lookup(lookup_hostname).await?;
125+
let mut hosts = vec![];
86126
let mut min_ttl = u32::MAX;
87-
88127
for record in srv_lookup.as_lookup().record_iter() {
89128
let srv = match record.data() {
90129
Some(RData::SRV(s)) => s,
91130
_ => continue,
92131
};
93-
94-
let hostname = srv.target().to_utf8();
95-
let port = Some(srv.port());
96-
let mut address = ServerAddress::Tcp {
97-
host: hostname,
98-
port,
99-
};
100-
101-
let domain_name = &hostname_parts[1..];
102-
103-
let host = address.host();
104-
let mut hostname_parts: Vec<_> = host.split('.').collect();
105-
106-
// Remove empty final section, which indicates a trailing dot.
107-
if hostname_parts.last().map(|s| s.is_empty()).unwrap_or(false) {
108-
hostname_parts.pop();
109-
}
110-
111-
if !&hostname_parts[1..].ends_with(domain_name) {
112-
let message = format!(
113-
"SRV lookup for {} returned result {}, which does not match domain name {}",
114-
original_hostname,
115-
address,
116-
domain_name.join(".")
117-
);
118-
if matches!(dm, DomainMismatch::Error) {
119-
return Err(ErrorKind::DnsResolve { message }.into());
120-
} else {
121-
#[cfg(feature = "tracing-unstable")]
122-
{
123-
use crate::trace::SERVER_SELECTION_TRACING_EVENT_TARGET;
124-
if crate::trace::trace_or_log_enabled!(
125-
target: SERVER_SELECTION_TRACING_EVENT_TARGET,
126-
crate::trace::TracingOrLogLevel::Warn
127-
) {
128-
tracing::warn!(
129-
target: SERVER_SELECTION_TRACING_EVENT_TARGET,
130-
message,
131-
);
132-
}
133-
}
134-
}
135-
continue;
132+
let mut host = srv.target().to_utf8();
133+
// Remove the trailing '.'
134+
if host.ends_with('.') {
135+
host.pop();
136136
}
137-
138-
// The spec tests list the seeds without the trailing '.', so we remove it by
139-
// joining the parts we split rather than manipulating the string.
140-
address = ServerAddress::Tcp {
141-
host: hostname_parts.join("."),
142-
port: address.port(),
143-
};
144-
137+
let port = Some(srv.port());
138+
hosts.push(ServerAddress::Tcp { host, port });
145139
min_ttl = std::cmp::min(min_ttl, record.ttl());
146-
srv_addresses.push(address);
147140
}
148-
149-
if srv_addresses.is_empty() {
150-
return Err(ErrorKind::DnsResolve {
151-
message: format!("SRV lookup for {} returned no records", original_hostname),
152-
}
153-
.into());
154-
}
155-
156141
Ok(LookupHosts {
157-
hosts: srv_addresses,
142+
hosts,
158143
min_ttl: Duration::from_secs(min_ttl.into()),
159144
})
160145
}
161146

147+
pub(crate) async fn get_srv_hosts(
148+
&self,
149+
original_hostname: &str,
150+
dm: DomainMismatch,
151+
) -> Result<LookupHosts> {
152+
let lookup_hostname = format!("_mongodb._tcp.{}", original_hostname);
153+
self.get_srv_hosts_unvalidated(&lookup_hostname)
154+
.await?
155+
.validate(original_hostname, dm)
156+
}
157+
162158
async fn get_txt_options(
163159
&self,
164160
original_hostname: &str,

src/test/spec/initial_dns_seedlist_discovery.rs

Lines changed: 44 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,8 @@ use serde::Deserialize;
55
use crate::{
66
bson::doc,
77
client::Client,
8-
options::{ClientOptions, ResolverConfig},
8+
options::{ClientOptions, ResolverConfig, ServerAddress},
9+
srv::{DomainMismatch, LookupHosts},
910
test::{get_client_options, log_uncaptured, run_spec_test},
1011
};
1112

@@ -255,3 +256,45 @@ async fn sharded() {
255256
}
256257
run_spec_test(&["initial-dns-seedlist-discovery", "sharded"], run_test).await;
257258
}
259+
260+
fn validate_srv(original: &str, resolved: &str) -> crate::error::Result<()> {
261+
LookupHosts {
262+
hosts: vec![ServerAddress::Tcp {
263+
host: resolved.to_string(),
264+
port: Some(42),
265+
}],
266+
min_ttl: Duration::from_secs(1),
267+
}
268+
.validate(original, DomainMismatch::Error)
269+
.map(|_| ())
270+
}
271+
272+
// Prose test 1. Allow SRVs with fewer than 3 `.` separated parts
273+
#[test]
274+
fn short_srv_domains_valid() {
275+
validate_srv("localhost", "test.localhost").unwrap();
276+
validate_srv("mongo.local", "test.mongo.local").unwrap();
277+
}
278+
279+
// Prose test 2. Throw when return address does not end with SRV domain
280+
#[test]
281+
fn short_srv_domains_invalid_end() {
282+
assert!(validate_srv("localhost", "localhost.mongodb").is_err());
283+
assert!(validate_srv("mongo.local", "test_1.evil.local").is_err());
284+
assert!(validate_srv("blogs.mongodb.com", "blogs.evil.com").is_err());
285+
}
286+
287+
// Prose test 3. Throw when return address is identical to SRV hostname
288+
#[test]
289+
fn short_srv_domains_invalid_identical() {
290+
assert!(validate_srv("localhost", "localhost").is_err());
291+
assert!(validate_srv("mongo.local", "mongo.local").is_err());
292+
}
293+
294+
// Prose test 4. Throw when return address does not contain `.` separating shared part of domain
295+
#[test]
296+
fn short_srv_domains_invalid_no_dot() {
297+
assert!(validate_srv("localhost", "test_1.cluster_1localhost").is_err());
298+
assert!(validate_srv("mongo.local", "test_1.my_hostmongo.local").is_err());
299+
assert!(validate_srv("blogs.mongodb.com", "cluster.testmongodb.com").is_err());
300+
}

0 commit comments

Comments
 (0)