@@ -19,6 +19,62 @@ pub(crate) struct LookupHosts {
19
19
pub ( crate ) min_ttl : Duration ,
20
20
}
21
21
22
+ impl LookupHosts {
23
+ pub ( crate ) fn validate ( mut self , original_hostname : & str , dm : DomainMismatch ) -> Result < Self > {
24
+ let original_hostname_parts: Vec < _ > = original_hostname. split ( '.' ) . collect ( ) ;
25
+ let original_domain_name = if original_hostname_parts. len ( ) >= 3 {
26
+ & original_hostname_parts[ 1 ..]
27
+ } else {
28
+ & original_hostname_parts[ ..]
29
+ } ;
30
+
31
+ let mut ok_hosts = vec ! [ ] ;
32
+ for addr in self . hosts . drain ( ..) {
33
+ let host = addr. host ( ) ;
34
+ let hostname_parts: Vec < _ > = host. split ( '.' ) . collect ( ) ;
35
+ if hostname_parts[ 1 ..] . ends_with ( original_domain_name) {
36
+ ok_hosts. push ( addr) ;
37
+ } else {
38
+ let message = format ! (
39
+ "SRV lookup for {} returned result {}, which does not match domain name {}" ,
40
+ original_hostname,
41
+ host,
42
+ original_domain_name. join( "." )
43
+ ) ;
44
+ match dm {
45
+ DomainMismatch :: Error => return Err ( ErrorKind :: DnsResolve { message } . into ( ) ) ,
46
+ DomainMismatch :: Skip => {
47
+ #[ cfg( feature = "tracing-unstable" ) ]
48
+ {
49
+ use crate :: trace:: SERVER_SELECTION_TRACING_EVENT_TARGET ;
50
+ if crate :: trace:: trace_or_log_enabled!(
51
+ target: SERVER_SELECTION_TRACING_EVENT_TARGET ,
52
+ crate :: trace:: TracingOrLogLevel :: Warn
53
+ ) {
54
+ tracing:: warn!(
55
+ target: SERVER_SELECTION_TRACING_EVENT_TARGET ,
56
+ message,
57
+ ) ;
58
+ }
59
+ }
60
+ continue ;
61
+ }
62
+ }
63
+ }
64
+ }
65
+ self . hosts = ok_hosts;
66
+
67
+ if self . hosts . is_empty ( ) {
68
+ return Err ( ErrorKind :: DnsResolve {
69
+ message : format ! ( "SRV lookup for {} returned no records" , original_hostname) ,
70
+ }
71
+ . into ( ) ) ;
72
+ }
73
+
74
+ Ok ( self )
75
+ }
76
+ }
77
+
22
78
#[ derive( Debug , Clone , PartialEq ) ]
23
79
pub ( crate ) struct OriginalSrvInfo {
24
80
pub ( crate ) hostname : String ,
@@ -62,103 +118,43 @@ impl SrvResolver {
62
118
Ok ( config)
63
119
}
64
120
65
- pub ( crate ) async fn get_srv_hosts (
66
- & self ,
67
- original_hostname : & str ,
68
- dm : DomainMismatch ,
69
- ) -> Result < LookupHosts > {
121
+ async fn get_srv_hosts_unvalidated ( & self , lookup_hostname : & str ) -> Result < LookupHosts > {
70
122
use hickory_proto:: rr:: RData ;
71
123
72
- let hostname_parts: Vec < _ > = original_hostname. split ( '.' ) . collect ( ) ;
73
-
74
- if hostname_parts. len ( ) < 3 {
75
- return Err ( ErrorKind :: InvalidArgument {
76
- message : "a 'mongodb+srv' hostname must have at least three '.'-delimited parts"
77
- . into ( ) ,
78
- }
79
- . into ( ) ) ;
80
- }
81
-
82
- let lookup_hostname = format ! ( "_mongodb._tcp.{}" , original_hostname) ;
83
-
84
- let srv_lookup = self . resolver . srv_lookup ( lookup_hostname. as_str ( ) ) . await ?;
85
- let mut srv_addresses: Vec < ServerAddress > = Vec :: new ( ) ;
124
+ let srv_lookup = self . resolver . srv_lookup ( lookup_hostname) . await ?;
125
+ let mut hosts = vec ! [ ] ;
86
126
let mut min_ttl = u32:: MAX ;
87
-
88
127
for record in srv_lookup. as_lookup ( ) . record_iter ( ) {
89
128
let srv = match record. data ( ) {
90
129
Some ( RData :: SRV ( s) ) => s,
91
130
_ => continue ,
92
131
} ;
93
-
94
- let hostname = srv. target ( ) . to_utf8 ( ) ;
95
- let port = Some ( srv. port ( ) ) ;
96
- let mut address = ServerAddress :: Tcp {
97
- host : hostname,
98
- port,
99
- } ;
100
-
101
- let domain_name = & hostname_parts[ 1 ..] ;
102
-
103
- let host = address. host ( ) ;
104
- let mut hostname_parts: Vec < _ > = host. split ( '.' ) . collect ( ) ;
105
-
106
- // Remove empty final section, which indicates a trailing dot.
107
- if hostname_parts. last ( ) . map ( |s| s. is_empty ( ) ) . unwrap_or ( false ) {
108
- hostname_parts. pop ( ) ;
109
- }
110
-
111
- if !& hostname_parts[ 1 ..] . ends_with ( domain_name) {
112
- let message = format ! (
113
- "SRV lookup for {} returned result {}, which does not match domain name {}" ,
114
- original_hostname,
115
- address,
116
- domain_name. join( "." )
117
- ) ;
118
- if matches ! ( dm, DomainMismatch :: Error ) {
119
- return Err ( ErrorKind :: DnsResolve { message } . into ( ) ) ;
120
- } else {
121
- #[ cfg( feature = "tracing-unstable" ) ]
122
- {
123
- use crate :: trace:: SERVER_SELECTION_TRACING_EVENT_TARGET ;
124
- if crate :: trace:: trace_or_log_enabled!(
125
- target: SERVER_SELECTION_TRACING_EVENT_TARGET ,
126
- crate :: trace:: TracingOrLogLevel :: Warn
127
- ) {
128
- tracing:: warn!(
129
- target: SERVER_SELECTION_TRACING_EVENT_TARGET ,
130
- message,
131
- ) ;
132
- }
133
- }
134
- }
135
- continue ;
132
+ let mut host = srv. target ( ) . to_utf8 ( ) ;
133
+ // Remove the trailing '.'
134
+ if host. ends_with ( '.' ) {
135
+ host. pop ( ) ;
136
136
}
137
-
138
- // The spec tests list the seeds without the trailing '.', so we remove it by
139
- // joining the parts we split rather than manipulating the string.
140
- address = ServerAddress :: Tcp {
141
- host : hostname_parts. join ( "." ) ,
142
- port : address. port ( ) ,
143
- } ;
144
-
137
+ let port = Some ( srv. port ( ) ) ;
138
+ hosts. push ( ServerAddress :: Tcp { host, port } ) ;
145
139
min_ttl = std:: cmp:: min ( min_ttl, record. ttl ( ) ) ;
146
- srv_addresses. push ( address) ;
147
140
}
148
-
149
- if srv_addresses. is_empty ( ) {
150
- return Err ( ErrorKind :: DnsResolve {
151
- message : format ! ( "SRV lookup for {} returned no records" , original_hostname) ,
152
- }
153
- . into ( ) ) ;
154
- }
155
-
156
141
Ok ( LookupHosts {
157
- hosts : srv_addresses ,
142
+ hosts,
158
143
min_ttl : Duration :: from_secs ( min_ttl. into ( ) ) ,
159
144
} )
160
145
}
161
146
147
+ pub ( crate ) async fn get_srv_hosts (
148
+ & self ,
149
+ original_hostname : & str ,
150
+ dm : DomainMismatch ,
151
+ ) -> Result < LookupHosts > {
152
+ let lookup_hostname = format ! ( "_mongodb._tcp.{}" , original_hostname) ;
153
+ self . get_srv_hosts_unvalidated ( & lookup_hostname)
154
+ . await ?
155
+ . validate ( original_hostname, dm)
156
+ }
157
+
162
158
async fn get_txt_options (
163
159
& self ,
164
160
original_hostname : & str ,
0 commit comments