Skip to content

Commit c2ff655

Browse files
OM pathfinding: ensure first_hops overrides network graph
1 parent d39966f commit c2ff655

File tree

1 file changed

+17
-2
lines changed

1 file changed

+17
-2
lines changed

lightning/src/routing/onion_message.rs

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -40,13 +40,15 @@ pub fn find_path<L: Deref, GL: Deref>(
4040

4141
// Add our start and first-hops to `frontier`.
4242
let start = NodeId::from_pubkey(&our_node_pubkey);
43+
let mut valid_first_hops = HashSet::new();
4344
let mut frontier = BinaryHeap::new();
4445
frontier.push(PathBuildingHop { cost: 0, node_id: start, parent_node_id: start });
4546
if let Some(first_hops) = first_hops {
4647
for hop in first_hops {
4748
if !hop.counterparty.features.supports_onion_messages() { continue; }
4849
let node_id = NodeId::from_pubkey(&hop.counterparty.node_id);
4950
frontier.push(PathBuildingHop { cost: 1, node_id, parent_node_id: start });
51+
valid_first_hops.insert(node_id);
5052
}
5153
}
5254

@@ -60,7 +62,7 @@ pub fn find_path<L: Deref, GL: Deref>(
6062
return Ok(reverse_path(visited, our_node_id, dest_node_id, logger)?)
6163
}
6264
if let Some(node_info) = network_nodes.get(&node_id) {
63-
if node_id == our_node_id {
65+
if valid_first_hops.contains(&node_id) || node_id == our_node_id {
6466
} else if let Some(node_ann) = &node_info.announcement_info {
6567
if !node_ann.features.supports_onion_messages() || node_ann.features.requires_unknown_bits()
6668
{ continue; }
@@ -146,7 +148,7 @@ fn reverse_path<L: Deref>(
146148
#[cfg(test)]
147149
mod tests {
148150
use ln::features::{InitFeatures, NodeFeatures};
149-
use routing::test_utils::{add_or_update_node, build_graph_with_features, build_line_graph, get_nodes};
151+
use routing::test_utils::{add_or_update_node, build_graph_with_features, build_line_graph, get_channel_details, get_nodes};
150152

151153
use sync::Arc;
152154

@@ -219,6 +221,13 @@ mod tests {
219221
// If all nodes require some features we don't understand, route should fail
220222
let err = super::find_path(&our_id, &node_pks[2], &network_graph, None, Arc::clone(&logger)).unwrap_err();
221223
assert_eq!(err, super::Error::PathNotFound);
224+
225+
// If we specify a channel to node7, that overrides our local channel view and that gets used
226+
let our_chans = vec![get_channel_details(Some(42), node_pks[7].clone(), features, 250_000_000)];
227+
let path = super::find_path(&our_id, &node_pks[2], &network_graph, Some(&our_chans.iter().collect::<Vec<_>>()), Arc::clone(&logger)).unwrap();
228+
assert_eq!(path.len(), 2);
229+
assert_eq!(path[0], node_pks[7]);
230+
assert_eq!(path[1], node_pks[2]);
222231
}
223232

224233
#[test]
@@ -236,6 +245,12 @@ mod tests {
236245
assert_eq!(path[0], node_pks[1]);
237246
assert_eq!(path[1], node_pks[2]);
238247
assert_eq!(path[2], node_pks[0]);
248+
249+
// If we specify a channel to node1, that overrides our local channel view and that gets used
250+
let our_chans = vec![get_channel_details(Some(42), node_pks[0].clone(), features, 250_000_000)];
251+
let path = super::find_path(&our_id, &node_pks[0], &network_graph, Some(&our_chans.iter().collect::<Vec<_>>()), Arc::clone(&logger)).unwrap();
252+
assert_eq!(path.len(), 1);
253+
assert_eq!(path[0], node_pks[0]);
239254
}
240255
}
241256

0 commit comments

Comments
 (0)