@@ -41,13 +41,15 @@ pub fn find_path<L: Deref, GL: Deref>(
41
41
42
42
// Add our start and first-hops to `frontier`.
43
43
let start = NodeId :: from_pubkey ( & our_node_pubkey) ;
44
+ let mut valid_first_hops = HashSet :: new ( ) ;
44
45
let mut frontier = BinaryHeap :: new ( ) ;
45
46
frontier. push ( PathBuildingHop { cost : 0 , node_id : start, parent_node_id : start } ) ;
46
47
if let Some ( first_hops) = first_hops {
47
48
for hop in first_hops {
48
49
if !hop. counterparty . features . supports_onion_messages ( ) { continue ; }
49
50
let node_id = NodeId :: from_pubkey ( & hop. counterparty . node_id ) ;
50
51
frontier. push ( PathBuildingHop { cost : 1 , node_id, parent_node_id : start } ) ;
52
+ valid_first_hops. insert ( node_id) ;
51
53
}
52
54
}
53
55
@@ -61,7 +63,7 @@ pub fn find_path<L: Deref, GL: Deref>(
61
63
return Ok ( reverse_path ( visited, our_node_id, dest_node_id, logger) ?)
62
64
}
63
65
if let Some ( node_info) = network_nodes. get ( & node_id) {
64
- if node_id == our_node_id {
66
+ if valid_first_hops . contains ( & node_id ) || node_id == our_node_id {
65
67
} else if let Some ( node_ann) = & node_info. announcement_info {
66
68
if !node_ann. features . supports_onion_messages ( ) || node_ann. features . requires_unknown_bits ( )
67
69
{ continue ; }
@@ -149,7 +151,7 @@ fn reverse_path<L: Deref>(
149
151
#[ cfg( test) ]
150
152
mod tests {
151
153
use ln:: features:: { InitFeatures , NodeFeatures } ;
152
- use routing:: test_utils:: { add_or_update_node, build_graph_with_features, build_line_graph, get_nodes} ;
154
+ use routing:: test_utils:: { add_or_update_node, build_graph_with_features, build_line_graph, get_channel_details , get_nodes} ;
153
155
154
156
use sync:: Arc ;
155
157
@@ -222,6 +224,13 @@ mod tests {
222
224
// If all nodes require some features we don't understand, route should fail
223
225
let err = super :: find_path ( & our_id, & node_pks[ 2 ] , & network_graph, None , Arc :: clone ( & logger) ) . unwrap_err ( ) ;
224
226
assert_eq ! ( err, super :: Error :: PathNotFound ) ;
227
+
228
+ // If we specify a channel to node7, that overrides our local channel view and that gets used
229
+ let our_chans = vec ! [ get_channel_details( Some ( 42 ) , node_pks[ 7 ] . clone( ) , features, 250_000_000 ) ] ;
230
+ let path = super :: find_path ( & our_id, & node_pks[ 2 ] , & network_graph, Some ( & our_chans. iter ( ) . collect :: < Vec < _ > > ( ) ) , Arc :: clone ( & logger) ) . unwrap ( ) ;
231
+ assert_eq ! ( path. len( ) , 2 ) ;
232
+ assert_eq ! ( path[ 0 ] , node_pks[ 7 ] ) ;
233
+ assert_eq ! ( path[ 1 ] , node_pks[ 2 ] ) ;
225
234
}
226
235
227
236
#[ test]
@@ -239,6 +248,12 @@ mod tests {
239
248
assert_eq ! ( path[ 0 ] , node_pks[ 1 ] ) ;
240
249
assert_eq ! ( path[ 1 ] , node_pks[ 2 ] ) ;
241
250
assert_eq ! ( path[ 2 ] , node_pks[ 0 ] ) ;
251
+
252
+ // If we specify a channel to node1, that overrides our local channel view and that gets used
253
+ let our_chans = vec ! [ get_channel_details( Some ( 42 ) , node_pks[ 0 ] . clone( ) , features, 250_000_000 ) ] ;
254
+ let path = super :: find_path ( & our_id, & node_pks[ 0 ] , & network_graph, Some ( & our_chans. iter ( ) . collect :: < Vec < _ > > ( ) ) , Arc :: clone ( & logger) ) . unwrap ( ) ;
255
+ assert_eq ! ( path. len( ) , 1 ) ;
256
+ assert_eq ! ( path[ 0 ] , node_pks[ 0 ] ) ;
242
257
}
243
258
}
244
259
0 commit comments