Skip to content

Commit d86bf6e

Browse files
OM pathfinding: ensure first_hops overrides network graph
1 parent 948d019 commit d86bf6e

File tree

1 file changed

+17
-2
lines changed

1 file changed

+17
-2
lines changed

lightning/src/routing/onion_message.rs

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -49,13 +49,15 @@ pub fn find_path<L: Deref, GL: Deref>(
4949

5050
// Add our start and first-hops to `frontier`.
5151
let start = NodeId::from_pubkey(&our_node_pubkey);
52+
let mut valid_first_hops = HashSet::new();
5253
let mut frontier = BinaryHeap::new();
5354
frontier.push(PathBuildingHop { cost: 0, node_id: start, parent_node_id: start });
5455
if let Some(first_hops) = first_hops {
5556
for hop in first_hops {
5657
if !hop.counterparty.features.supports_onion_messages() { continue; }
5758
let node_id = NodeId::from_pubkey(&hop.counterparty.node_id);
5859
frontier.push(PathBuildingHop { cost: 1, node_id, parent_node_id: start });
60+
valid_first_hops.insert(node_id);
5961
}
6062
}
6163

@@ -71,7 +73,7 @@ pub fn find_path<L: Deref, GL: Deref>(
7173
return Ok(path)
7274
}
7375
if let Some(node_info) = network_nodes.get(&node_id) {
74-
if node_id == our_node_id {
76+
if valid_first_hops.contains(&node_id) || node_id == our_node_id {
7577
} else if let Some(node_ann) = &node_info.announcement_info {
7678
if !node_ann.features.supports_onion_messages() || node_ann.features.requires_unknown_bits()
7779
{ continue; }
@@ -166,7 +168,7 @@ fn reverse_path(
166168
#[cfg(test)]
167169
mod tests {
168170
use ln::features::{InitFeatures, NodeFeatures};
169-
use routing::test_utils::{add_or_update_node, build_graph_with_features, build_line_graph, get_nodes};
171+
use routing::test_utils::{add_or_update_node, build_graph_with_features, build_line_graph, get_channel_details, get_nodes};
170172

171173
use sync::Arc;
172174

@@ -239,6 +241,13 @@ mod tests {
239241
// If all nodes require some features we don't understand, route should fail
240242
let err = super::find_path(&our_id, &node_pks[2], &network_graph, None, Arc::clone(&logger)).unwrap_err();
241243
assert_eq!(err, super::Error::PathNotFound);
244+
245+
// If we specify a channel to node7, that overrides our local channel view and that gets used
246+
let our_chans = vec![get_channel_details(Some(42), node_pks[7].clone(), features, 250_000_000)];
247+
let path = super::find_path(&our_id, &node_pks[2], &network_graph, Some(&our_chans.iter().collect::<Vec<_>>()), Arc::clone(&logger)).unwrap();
248+
assert_eq!(path.len(), 2);
249+
assert_eq!(path[0], node_pks[7]);
250+
assert_eq!(path[1], node_pks[2]);
242251
}
243252

244253
#[test]
@@ -256,6 +265,12 @@ mod tests {
256265
assert_eq!(path[0], node_pks[1]);
257266
assert_eq!(path[1], node_pks[2]);
258267
assert_eq!(path[2], node_pks[0]);
268+
269+
// If we specify a channel to node1, that overrides our local channel view and that gets used
270+
let our_chans = vec![get_channel_details(Some(42), node_pks[0].clone(), features, 250_000_000)];
271+
let path = super::find_path(&our_id, &node_pks[0], &network_graph, Some(&our_chans.iter().collect::<Vec<_>>()), Arc::clone(&logger)).unwrap();
272+
assert_eq!(path.len(), 1);
273+
assert_eq!(path[0], node_pks[0]);
259274
}
260275
}
261276

0 commit comments

Comments
 (0)