Skip to content

Commit 083828a

Browse files
authored
Merge pull request #1133 from jkczyz/2021-10-expand-scorer
Include source and destination nodes in routing::Score
2 parents 2bf39a6 + 54f490c commit 083828a

File tree

3 files changed

+95
-33
lines changed

3 files changed

+95
-33
lines changed

lightning/src/routing/mod.rs

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,10 +13,13 @@ pub mod network_graph;
1313
pub mod router;
1414
pub mod scorer;
1515

16+
use routing::network_graph::NodeId;
17+
1618
/// An interface used to score payment channels for path finding.
1719
///
1820
/// Scoring is in terms of fees willing to be paid in order to avoid routing through a channel.
1921
pub trait Score {
20-
/// Returns the fee in msats willing to be paid to avoid routing through the given channel.
21-
fn channel_penalty_msat(&self, short_channel_id: u64) -> u64;
22+
/// Returns the fee in msats willing to be paid to avoid routing through the given channel
23+
/// in the direction from `source` to `target`.
24+
fn channel_penalty_msat(&self, short_channel_id: u64, source: &NodeId, target: &NodeId) -> u64;
2225
}

lightning/src/routing/router.rs

Lines changed: 83 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -748,7 +748,7 @@ where L::Target: Logger {
748748
}
749749

750750
let path_penalty_msat = $next_hops_path_penalty_msat
751-
.checked_add(scorer.channel_penalty_msat($chan_id.clone()))
751+
.checked_add(scorer.channel_penalty_msat($chan_id.clone(), &$src_node_id, &$dest_node_id))
752752
.unwrap_or_else(|| u64::max_value());
753753
let new_graph_node = RouteGraphNode {
754754
node_id: $src_node_id,
@@ -973,15 +973,17 @@ where L::Target: Logger {
973973
_ => aggregate_next_hops_fee_msat.checked_add(999).unwrap_or(u64::max_value())
974974
}) { Some( val / 1000 ) } else { break; }; // converting from msat or breaking if max ~ infinity
975975

976+
let src_node_id = NodeId::from_pubkey(&hop.src_node_id);
977+
let dest_node_id = NodeId::from_pubkey(&prev_hop_id);
976978
aggregate_next_hops_path_penalty_msat = aggregate_next_hops_path_penalty_msat
977-
.checked_add(scorer.channel_penalty_msat(hop.short_channel_id))
979+
.checked_add(scorer.channel_penalty_msat(hop.short_channel_id, &src_node_id, &dest_node_id))
978980
.unwrap_or_else(|| u64::max_value());
979981

980982
// We assume that the recipient only included route hints for routes which had
981983
// sufficient value to route `final_value_msat`. Note that in the case of "0-value"
982984
// invoices where the invoice does not specify value this may not be the case, but
983985
// better to include the hints than not.
984-
if !add_entry!(hop.short_channel_id, NodeId::from_pubkey(&hop.src_node_id), NodeId::from_pubkey(&prev_hop_id), directional_info, reqd_channel_cap, &empty_channel_features, aggregate_next_hops_fee_msat, path_value_msat, aggregate_next_hops_path_htlc_minimum_msat, aggregate_next_hops_path_penalty_msat) {
986+
if !add_entry!(hop.short_channel_id, src_node_id, dest_node_id, directional_info, reqd_channel_cap, &empty_channel_features, aggregate_next_hops_fee_msat, path_value_msat, aggregate_next_hops_path_htlc_minimum_msat, aggregate_next_hops_path_penalty_msat) {
985987
// If this hop was not used then there is no use checking the preceding hops
986988
// in the RouteHint. We can break by just searching for a direct channel between
987989
// last checked hop and first_hop_targets
@@ -1322,7 +1324,8 @@ where L::Target: Logger {
13221324

13231325
#[cfg(test)]
13241326
mod tests {
1325-
use routing::network_graph::{NetworkGraph, NetGraphMsgHandler};
1327+
use routing;
1328+
use routing::network_graph::{NetworkGraph, NetGraphMsgHandler, NodeId};
13261329
use routing::router::{get_route, Route, RouteHint, RouteHintHop, RouteHop, RoutingFees};
13271330
use routing::scorer::Scorer;
13281331
use chain::transaction::OutPoint;
@@ -4351,42 +4354,92 @@ mod tests {
43514354
let (secp_ctx, net_graph_msg_handler, _, logger) = build_graph();
43524355
let (_, our_id, _, nodes) = get_nodes(&secp_ctx);
43534356

4357+
// Without penalizing each hop 100 msats, a longer path with lower fees is chosen.
4358+
let scorer = Scorer::new(0);
4359+
let route = get_route(
4360+
&our_id, &net_graph_msg_handler.network_graph, &nodes[6], None, None,
4361+
&last_hops(&nodes).iter().collect::<Vec<_>>(), 100, 42, Arc::clone(&logger), &scorer
4362+
).unwrap();
4363+
let path = route.paths[0].iter().map(|hop| hop.short_channel_id).collect::<Vec<_>>();
4364+
4365+
assert_eq!(route.get_total_fees(), 100);
4366+
assert_eq!(route.get_total_amount(), 100);
4367+
assert_eq!(path, vec![2, 4, 6, 11, 8]);
4368+
43544369
// Applying a 100 msat penalty to each hop results in taking channels 7 and 10 to nodes[6]
43554370
// from nodes[2] rather than channel 6, 11, and 8, even though the longer path is cheaper.
43564371
let scorer = Scorer::new(100);
4357-
let route = get_route(&our_id, &net_graph_msg_handler.network_graph, &nodes[6], None, None, &last_hops(&nodes).iter().collect::<Vec<_>>(), 100, 42, Arc::clone(&logger), &scorer).unwrap();
4358-
assert_eq!(route.paths[0].len(), 4);
4372+
let route = get_route(
4373+
&our_id, &net_graph_msg_handler.network_graph, &nodes[6], None, None,
4374+
&last_hops(&nodes).iter().collect::<Vec<_>>(), 100, 42, Arc::clone(&logger), &scorer
4375+
).unwrap();
4376+
let path = route.paths[0].iter().map(|hop| hop.short_channel_id).collect::<Vec<_>>();
43594377

4360-
assert_eq!(route.paths[0][0].pubkey, nodes[1]);
4361-
assert_eq!(route.paths[0][0].short_channel_id, 2);
4362-
assert_eq!(route.paths[0][0].fee_msat, 200);
4363-
assert_eq!(route.paths[0][0].cltv_expiry_delta, (4 << 8) | 1);
4364-
assert_eq!(route.paths[0][0].node_features.le_flags(), &id_to_feature_flags(2));
4365-
assert_eq!(route.paths[0][0].channel_features.le_flags(), &id_to_feature_flags(2));
4378+
assert_eq!(route.get_total_fees(), 300);
4379+
assert_eq!(route.get_total_amount(), 100);
4380+
assert_eq!(path, vec![2, 4, 7, 10]);
4381+
}
43664382

4367-
assert_eq!(route.paths[0][1].pubkey, nodes[2]);
4368-
assert_eq!(route.paths[0][1].short_channel_id, 4);
4369-
assert_eq!(route.paths[0][1].fee_msat, 100);
4370-
assert_eq!(route.paths[0][1].cltv_expiry_delta, (7 << 8) | 1);
4371-
assert_eq!(route.paths[0][1].node_features.le_flags(), &id_to_feature_flags(3));
4372-
assert_eq!(route.paths[0][1].channel_features.le_flags(), &id_to_feature_flags(4));
4383+
struct BadChannelScorer {
4384+
short_channel_id: u64,
4385+
}
43734386

4374-
assert_eq!(route.paths[0][2].pubkey, nodes[5]);
4375-
assert_eq!(route.paths[0][2].short_channel_id, 7);
4376-
assert_eq!(route.paths[0][2].fee_msat, 0);
4377-
assert_eq!(route.paths[0][2].cltv_expiry_delta, (10 << 8) | 1);
4378-
assert_eq!(route.paths[0][2].node_features.le_flags(), &id_to_feature_flags(6));
4379-
assert_eq!(route.paths[0][2].channel_features.le_flags(), &id_to_feature_flags(7));
4387+
impl routing::Score for BadChannelScorer {
4388+
fn channel_penalty_msat(&self, short_channel_id: u64, _source: &NodeId, _target: &NodeId) -> u64 {
4389+
if short_channel_id == self.short_channel_id { u64::max_value() } else { 0 }
4390+
}
4391+
}
43804392

4381-
assert_eq!(route.paths[0][3].pubkey, nodes[6]);
4382-
assert_eq!(route.paths[0][3].short_channel_id, 10);
4383-
assert_eq!(route.paths[0][3].fee_msat, 100);
4384-
assert_eq!(route.paths[0][3].cltv_expiry_delta, 42);
4385-
assert_eq!(route.paths[0][3].node_features.le_flags(), &Vec::<u8>::new()); // We don't pass flags in from invoices yet
4386-
assert_eq!(route.paths[0][3].channel_features.le_flags(), &Vec::<u8>::new()); // We can't learn any flags from invoices, sadly
4393+
struct BadNodeScorer {
4394+
node_id: NodeId,
4395+
}
4396+
4397+
impl routing::Score for BadNodeScorer {
4398+
fn channel_penalty_msat(&self, _short_channel_id: u64, _source: &NodeId, target: &NodeId) -> u64 {
4399+
if *target == self.node_id { u64::max_value() } else { 0 }
4400+
}
4401+
}
4402+
4403+
#[test]
4404+
fn avoids_routing_through_bad_channels_and_nodes() {
4405+
let (secp_ctx, net_graph_msg_handler, _, logger) = build_graph();
4406+
let (_, our_id, _, nodes) = get_nodes(&secp_ctx);
4407+
4408+
// A path to nodes[6] exists when no penalties are applied to any channel.
4409+
let scorer = Scorer::new(0);
4410+
let route = get_route(
4411+
&our_id, &net_graph_msg_handler.network_graph, &nodes[6], None, None,
4412+
&last_hops(&nodes).iter().collect::<Vec<_>>(), 100, 42, Arc::clone(&logger), &scorer
4413+
).unwrap();
4414+
let path = route.paths[0].iter().map(|hop| hop.short_channel_id).collect::<Vec<_>>();
4415+
4416+
assert_eq!(route.get_total_fees(), 100);
4417+
assert_eq!(route.get_total_amount(), 100);
4418+
assert_eq!(path, vec![2, 4, 6, 11, 8]);
4419+
4420+
// A different path to nodes[6] exists if channel 6 cannot be routed over.
4421+
let scorer = BadChannelScorer { short_channel_id: 6 };
4422+
let route = get_route(
4423+
&our_id, &net_graph_msg_handler.network_graph, &nodes[6], None, None,
4424+
&last_hops(&nodes).iter().collect::<Vec<_>>(), 100, 42, Arc::clone(&logger), &scorer
4425+
).unwrap();
4426+
let path = route.paths[0].iter().map(|hop| hop.short_channel_id).collect::<Vec<_>>();
43874427

43884428
assert_eq!(route.get_total_fees(), 300);
43894429
assert_eq!(route.get_total_amount(), 100);
4430+
assert_eq!(path, vec![2, 4, 7, 10]);
4431+
4432+
// A path to nodes[6] does not exist if nodes[2] cannot be routed through.
4433+
let scorer = BadNodeScorer { node_id: NodeId::from_pubkey(&nodes[2]) };
4434+
match get_route(
4435+
&our_id, &net_graph_msg_handler.network_graph, &nodes[6], None, None,
4436+
&last_hops(&nodes).iter().collect::<Vec<_>>(), 100, 42, Arc::clone(&logger), &scorer
4437+
) {
4438+
Err(LightningError { err, .. } ) => {
4439+
assert_eq!(err, "Failed to find a path to the given destination");
4440+
},
4441+
Ok(_) => panic!("Expected error"),
4442+
}
43904443
}
43914444

43924445
#[test]

lightning/src/routing/scorer.rs

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,8 @@
4444
4545
use routing;
4646

47+
use routing::network_graph::NodeId;
48+
4749
/// [`routing::Score`] implementation that provides reasonable default behavior.
4850
///
4951
/// Used to apply a fixed penalty to each channel, thus avoiding long paths when shorter paths with
@@ -71,5 +73,9 @@ impl Default for Scorer {
7173
}
7274

7375
impl routing::Score for Scorer {
74-
fn channel_penalty_msat(&self, _short_channel_id: u64) -> u64 { self.base_penalty_msat }
76+
fn channel_penalty_msat(
77+
&self, _short_channel_id: u64, _source: &NodeId, _target: &NodeId
78+
) -> u64 {
79+
self.base_penalty_msat
80+
}
7581
}

0 commit comments

Comments
 (0)