Skip to content

Commit 846dd0f

Browse files
OM pathfinding: ensure first_hops overrides network graph
1 parent aba9e61 commit 846dd0f

File tree

1 file changed

+17
-2
lines changed

1 file changed

+17
-2
lines changed

lightning/src/routing/onion_message.rs

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -53,13 +53,15 @@ pub fn find_path<L: Deref, GL: Deref>(
5353

5454
// Add our start and first-hops to `frontier`.
5555
let start = NodeId::from_pubkey(&our_node_pubkey);
56+
let mut valid_first_hops = HashSet::new();
5657
let mut frontier = BinaryHeap::new();
5758
frontier.push(PathBuildingHop { cost: 0, node_id: start, parent_node_id: start });
5859
if let Some(first_hops) = first_hops {
5960
for hop in first_hops {
6061
if !hop.counterparty.features.supports_onion_messages() { continue; }
6162
let node_id = NodeId::from_pubkey(&hop.counterparty.node_id);
6263
frontier.push(PathBuildingHop { cost: 1, node_id, parent_node_id: start });
64+
valid_first_hops.insert(node_id);
6365
}
6466
}
6567

@@ -75,7 +77,7 @@ pub fn find_path<L: Deref, GL: Deref>(
7577
return Ok(path)
7678
}
7779
if let Some(node_info) = network_nodes.get(&node_id) {
78-
if node_id == our_node_id {
80+
if valid_first_hops.contains(&node_id) || node_id == our_node_id {
7981
} else if let Some(node_ann) = &node_info.announcement_info {
8082
if !node_ann.features.supports_onion_messages() || node_ann.features.requires_unknown_bits()
8183
{ continue; }
@@ -173,7 +175,7 @@ fn reverse_path(
173175
#[cfg(test)]
174176
mod tests {
175177
use ln::features::{InitFeatures, NodeFeatures};
176-
use routing::test_utils::{add_or_update_node, build_graph_with_features, build_line_graph, get_nodes};
178+
use routing::test_utils::{add_or_update_node, build_graph_with_features, build_line_graph, get_channel_details, get_nodes};
177179

178180
use sync::Arc;
179181

@@ -246,6 +248,13 @@ mod tests {
246248
// If all nodes require some features we don't understand, route should fail
247249
let err = super::find_path(&our_id, &node_pks[2], &network_graph, None, Arc::clone(&logger)).unwrap_err();
248250
assert_eq!(err, super::Error::PathNotFound);
251+
252+
// If we specify a channel to node7, that overrides our local channel view and that gets used
253+
let our_chans = vec![get_channel_details(Some(42), node_pks[7].clone(), features, 250_000_000)];
254+
let path = super::find_path(&our_id, &node_pks[2], &network_graph, Some(&our_chans.iter().collect::<Vec<_>>()), Arc::clone(&logger)).unwrap();
255+
assert_eq!(path.len(), 2);
256+
assert_eq!(path[0], node_pks[7]);
257+
assert_eq!(path[1], node_pks[2]);
249258
}
250259

251260
#[test]
@@ -263,5 +272,11 @@ mod tests {
263272
assert_eq!(path[0], node_pks[1]);
264273
assert_eq!(path[1], node_pks[2]);
265274
assert_eq!(path[2], node_pks[0]);
275+
276+
// If we specify a channel to node1, that overrides our local channel view and that gets used
277+
let our_chans = vec![get_channel_details(Some(42), node_pks[0].clone(), features, 250_000_000)];
278+
let path = super::find_path(&our_id, &node_pks[0], &network_graph, Some(&our_chans.iter().collect::<Vec<_>>()), Arc::clone(&logger)).unwrap();
279+
assert_eq!(path.len(), 1);
280+
assert_eq!(path[0], node_pks[0]);
266281
}
267282
}

0 commit comments

Comments
 (0)