Skip to content

Commit 9372f1c

Browse files
OM pathfinding: ensure first_hops overrides network graph
1 parent 7ce59d3 commit 9372f1c

File tree

1 file changed

+17
-2
lines changed

1 file changed

+17
-2
lines changed

lightning/src/routing/onion_message.rs

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -41,13 +41,15 @@ pub fn find_path<L: Deref, GL: Deref>(
4141

4242
// Add our start and first-hops to `frontier`.
4343
let start = NodeId::from_pubkey(&our_node_pubkey);
44+
let mut valid_first_hops = HashSet::new();
4445
let mut frontier = BinaryHeap::new();
4546
frontier.push(PathBuildingHop { cost: 0, node_id: start, parent_node_id: start });
4647
if let Some(first_hops) = first_hops {
4748
for hop in first_hops {
4849
if !hop.counterparty.features.supports_onion_messages() { continue; }
4950
let node_id = NodeId::from_pubkey(&hop.counterparty.node_id);
5051
frontier.push(PathBuildingHop { cost: 1, node_id, parent_node_id: start });
52+
valid_first_hops.insert(node_id);
5153
}
5254
}
5355

@@ -61,7 +63,7 @@ pub fn find_path<L: Deref, GL: Deref>(
6163
return Ok(reverse_path(visited, our_node_id, dest_node_id, logger)?)
6264
}
6365
if let Some(node_info) = network_nodes.get(&node_id) {
64-
if node_id == our_node_id {
66+
if valid_first_hops.contains(&node_id) || node_id == our_node_id {
6567
} else if let Some(node_ann) = &node_info.announcement_info {
6668
if !node_ann.features.supports_onion_messages() || node_ann.features.requires_unknown_bits()
6769
{ continue; }
@@ -149,7 +151,7 @@ fn reverse_path<L: Deref>(
149151
#[cfg(test)]
150152
mod tests {
151153
use ln::features::{InitFeatures, NodeFeatures};
152-
use routing::test_utils::{add_or_update_node, build_graph_with_features, build_line_graph, get_nodes};
154+
use routing::test_utils::{add_or_update_node, build_graph_with_features, build_line_graph, get_channel_details, get_nodes};
153155

154156
use sync::Arc;
155157

@@ -222,6 +224,13 @@ mod tests {
222224
// If all nodes require some features we don't understand, route should fail
223225
let err = super::find_path(&our_id, &node_pks[2], &network_graph, None, Arc::clone(&logger)).unwrap_err();
224226
assert_eq!(err, super::Error::PathNotFound);
227+
228+
// If we specify a channel to node7, that overrides our local channel view and that gets used
229+
let our_chans = vec![get_channel_details(Some(42), node_pks[7].clone(), features, 250_000_000)];
230+
let path = super::find_path(&our_id, &node_pks[2], &network_graph, Some(&our_chans.iter().collect::<Vec<_>>()), Arc::clone(&logger)).unwrap();
231+
assert_eq!(path.len(), 2);
232+
assert_eq!(path[0], node_pks[7]);
233+
assert_eq!(path[1], node_pks[2]);
225234
}
226235

227236
#[test]
@@ -239,6 +248,12 @@ mod tests {
239248
assert_eq!(path[0], node_pks[1]);
240249
assert_eq!(path[1], node_pks[2]);
241250
assert_eq!(path[2], node_pks[0]);
251+
252+
// If we specify a channel to node1, that overrides our local channel view and that gets used
253+
let our_chans = vec![get_channel_details(Some(42), node_pks[0].clone(), features, 250_000_000)];
254+
let path = super::find_path(&our_id, &node_pks[0], &network_graph, Some(&our_chans.iter().collect::<Vec<_>>()), Arc::clone(&logger)).unwrap();
255+
assert_eq!(path.len(), 1);
256+
assert_eq!(path[0], node_pks[0]);
242257
}
243258
}
244259

0 commit comments

Comments
 (0)