diff --git a/fuzz/src/full_stack.rs b/fuzz/src/full_stack.rs index b01506871ec..2e447dac6da 100644 --- a/fuzz/src/full_stack.rs +++ b/fuzz/src/full_stack.rs @@ -382,7 +382,7 @@ pub fn do_test(data: &[u8], logger: &Arc) { let our_id = PublicKey::from_secret_key(&Secp256k1::signing_only(), &keys_manager.get_node_secret()); let network_graph = NetworkGraph::new(genesis_block(network).block_hash()); let net_graph_msg_handler = Arc::new(NetGraphMsgHandler::new(network_graph, None, Arc::clone(&logger))); - let scorer = Scorer::new(0); + let scorer = Scorer::with_fixed_penalty(0); let peers = RefCell::new([false; 256]); let mut loss_detector = MoneyLossDetector::new(&peers, channelmanager.clone(), monitor.clone(), PeerManager::new(MessageHandler { diff --git a/fuzz/src/router.rs b/fuzz/src/router.rs index 7f7d9585cc2..abd83fa58c6 100644 --- a/fuzz/src/router.rs +++ b/fuzz/src/router.rs @@ -248,7 +248,7 @@ pub fn do_test(data: &[u8], out: Out) { }])); } } - let scorer = Scorer::new(0); + let scorer = Scorer::with_fixed_penalty(0); for target in node_pks.iter() { let params = RouteParameters { payee: Payee::new(*target).with_route_hints(last_hops.clone()), diff --git a/lightning-background-processor/Cargo.toml b/lightning-background-processor/Cargo.toml index 4e45bb2a83c..d868f14db74 100644 --- a/lightning-background-processor/Cargo.toml +++ b/lightning-background-processor/Cargo.toml @@ -16,4 +16,4 @@ lightning-persister = { version = "0.0.102", path = "../lightning-persister" } [dev-dependencies] lightning = { version = "0.0.102", path = "../lightning", features = ["_test_utils"] } - +lightning-invoice = { version = "0.10.0", path = "../lightning-invoice" } diff --git a/lightning-background-processor/src/lib.rs b/lightning-background-processor/src/lib.rs index e38a4a975b2..3d26ec84a75 100644 --- a/lightning-background-processor/src/lib.rs +++ b/lightning-background-processor/src/lib.rs @@ -174,7 +174,7 @@ impl BackgroundProcessor { Descriptor: 'static + SocketDescriptor + Send + Sync, CMH: 'static + Deref + Send + Sync, RMH: 'static + Deref + Send + Sync, - EH: 'static + EventHandler + Send + Sync, + EH: 'static + EventHandler + Send, CMP: 'static + Send + ChannelManagerPersister, M: 'static + Deref> + Send + Sync, CM: 'static + Deref> + Send + Sync, @@ -311,11 +311,14 @@ mod tests { use lightning::ln::features::InitFeatures; use lightning::ln::msgs::{ChannelMessageHandler, Init}; use lightning::ln::peer_handler::{PeerManager, MessageHandler, SocketDescriptor, IgnoringMessageHandler}; + use lightning::routing::scorer::Scorer; use lightning::routing::network_graph::{NetworkGraph, NetGraphMsgHandler}; use lightning::util::config::UserConfig; use lightning::util::events::{Event, MessageSendEventsProvider, MessageSendEvent}; use lightning::util::ser::Writeable; use lightning::util::test_utils; + use lightning_invoice::payment::{InvoicePayer, RetryAttempts}; + use lightning_invoice::utils::DefaultRouter; use lightning_persister::FilesystemPersister; use std::fs; use std::path::PathBuf; @@ -621,4 +624,20 @@ mod tests { assert!(bg_processor.stop().is_ok()); } + + #[test] + fn test_invoice_payer() { + let nodes = create_nodes(2, "test_invoice_payer".to_string()); + + // Initiate the background processors to watch each node. + let data_dir = nodes[0].persister.get_data_dir(); + let persister = move |node: &ChannelManager, Arc, Arc, Arc, Arc>| FilesystemPersister::persist_manager(data_dir.clone(), node); + let network_graph = Arc::new(NetworkGraph::new(genesis_block(Network::Testnet).header.block_hash())); + let router = DefaultRouter::new(network_graph, Arc::clone(&nodes[0].logger)); + let scorer = Arc::new(Mutex::new(Scorer::default())); + let invoice_payer = Arc::new(InvoicePayer::new(Arc::clone(&nodes[0].node), router, scorer, Arc::clone(&nodes[0].logger), |_: &_| {}, RetryAttempts(2))); + let event_handler = Arc::clone(&invoice_payer); + let bg_processor = BackgroundProcessor::start(persister, event_handler, nodes[0].chain_monitor.clone(), nodes[0].node.clone(), nodes[0].net_graph_msg_handler.clone(), nodes[0].peer_manager.clone(), nodes[0].logger.clone()); + assert!(bg_processor.stop().is_ok()); + } } diff --git a/lightning-invoice/src/payment.rs b/lightning-invoice/src/payment.rs index 7e931d66f15..ba260742bb4 100644 --- a/lightning-invoice/src/payment.rs +++ b/lightning-invoice/src/payment.rs @@ -30,12 +30,15 @@ //! # use lightning::ln::{PaymentHash, PaymentSecret}; //! # use lightning::ln::channelmanager::{ChannelDetails, PaymentId, PaymentSendFailure}; //! # use lightning::ln::msgs::LightningError; -//! # use lightning::routing::router::{Route, RouteParameters}; +//! # use lightning::routing; +//! # use lightning::routing::network_graph::NodeId; +//! # use lightning::routing::router::{Route, RouteHop, RouteParameters}; //! # use lightning::util::events::{Event, EventHandler, EventsProvider}; //! # use lightning::util::logger::{Logger, Record}; //! # use lightning_invoice::Invoice; //! # use lightning_invoice::payment::{InvoicePayer, Payer, RetryAttempts, Router}; //! # use secp256k1::key::PublicKey; +//! # use std::cell::RefCell; //! # use std::ops::Deref; //! # //! # struct FakeEventProvider {} @@ -56,13 +59,21 @@ //! # } //! # //! # struct FakeRouter {}; -//! # impl Router for FakeRouter { +//! # impl Router for FakeRouter { //! # fn find_route( //! # &self, payer: &PublicKey, params: &RouteParameters, -//! # first_hops: Option<&[&ChannelDetails]> +//! # first_hops: Option<&[&ChannelDetails]>, scorer: &S //! # ) -> Result { unimplemented!() } //! # } //! # +//! # struct FakeScorer {}; +//! # impl routing::Score for FakeScorer { +//! # fn channel_penalty_msat( +//! # &self, _short_channel_id: u64, _source: &NodeId, _target: &NodeId +//! # ) -> u64 { 0 } +//! # fn payment_path_failed(&mut self, _path: &Vec, _short_channel_id: u64) {} +//! # } +//! # //! # struct FakeLogger {}; //! # impl Logger for FakeLogger { //! # fn log(&self, record: &Record) { unimplemented!() } @@ -78,8 +89,9 @@ //! }; //! # let payer = FakePayer {}; //! # let router = FakeRouter {}; +//! # let scorer = RefCell::new(FakeScorer {}); //! # let logger = FakeLogger {}; -//! let invoice_payer = InvoicePayer::new(&payer, router, &logger, event_handler, RetryAttempts(2)); +//! let invoice_payer = InvoicePayer::new(&payer, router, &scorer, &logger, event_handler, RetryAttempts(2)); //! //! let invoice = "..."; //! let invoice = invoice.parse::().unwrap(); @@ -105,6 +117,8 @@ use bitcoin_hashes::Hash; use lightning::ln::{PaymentHash, PaymentSecret}; use lightning::ln::channelmanager::{ChannelDetails, PaymentId, PaymentSendFailure}; use lightning::ln::msgs::LightningError; +use lightning::routing; +use lightning::routing::{LockableScore, Score}; use lightning::routing::router::{Payee, Route, RouteParameters}; use lightning::util::events::{Event, EventHandler}; use lightning::util::logger::Logger; @@ -117,15 +131,17 @@ use std::sync::Mutex; use std::time::{Duration, SystemTime}; /// A utility for paying [`Invoice]`s. -pub struct InvoicePayer +pub struct InvoicePayer where P::Target: Payer, - R: Router, + R: for <'a> Router<<::Target as routing::LockableScore<'a>>::Locked>, + S::Target: for <'a> routing::LockableScore<'a>, L::Target: Logger, E: EventHandler, { payer: P, router: R, + scorer: S, logger: L, event_handler: E, payment_cache: Mutex>, @@ -150,10 +166,11 @@ pub trait Payer { } /// A trait defining behavior for routing an [`Invoice`] payment. -pub trait Router { +pub trait Router { /// Finds a [`Route`] between `payer` and `payee` for a payment with the given values. fn find_route( - &self, payer: &PublicKey, params: &RouteParameters, first_hops: Option<&[&ChannelDetails]> + &self, payer: &PublicKey, params: &RouteParameters, first_hops: Option<&[&ChannelDetails]>, + scorer: &S ) -> Result; } @@ -172,10 +189,11 @@ pub enum PaymentError { Sending(PaymentSendFailure), } -impl InvoicePayer +impl InvoicePayer where P::Target: Payer, - R: Router, + R: for <'a> Router<<::Target as routing::LockableScore<'a>>::Locked>, + S::Target: for <'a> routing::LockableScore<'a>, L::Target: Logger, E: EventHandler, { @@ -184,11 +202,12 @@ where /// Will forward any [`Event::PaymentPathFailed`] events to the decorated `event_handler` once /// `retry_attempts` has been exceeded for a given [`Invoice`]. pub fn new( - payer: P, router: R, logger: L, event_handler: E, retry_attempts: RetryAttempts + payer: P, router: R, scorer: S, logger: L, event_handler: E, retry_attempts: RetryAttempts ) -> Self { Self { payer, router, + scorer, logger, event_handler, payment_cache: Mutex::new(HashMap::new()), @@ -242,6 +261,7 @@ where &payer, ¶ms, Some(&first_hops.iter().collect::>()), + &self.scorer.lock(), ).map_err(|e| PaymentError::Routing(e))?; let payment_hash = PaymentHash(invoice.payment_hash().clone().into_inner()); @@ -261,7 +281,8 @@ where let payer = self.payer.node_id(); let first_hops = self.payer.first_hops(); let route = self.router.find_route( - &payer, ¶ms, Some(&first_hops.iter().collect::>()) + &payer, ¶ms, Some(&first_hops.iter().collect::>()), + &self.scorer.lock() ).map_err(|e| PaymentError::Routing(e))?; self.payer.retry_payment(&route, payment_id).map_err(|e| PaymentError::Sending(e)) } @@ -284,16 +305,23 @@ fn has_expired(params: &RouteParameters) -> bool { Invoice::is_expired_from_epoch(&SystemTime::UNIX_EPOCH, expiry_time) } -impl EventHandler for InvoicePayer +impl EventHandler for InvoicePayer where P::Target: Payer, - R: Router, + R: for <'a> Router<<::Target as routing::LockableScore<'a>>::Locked>, + S::Target: for <'a> routing::LockableScore<'a>, L::Target: Logger, E: EventHandler, { fn handle_event(&self, event: &Event) { match event { - Event::PaymentPathFailed { payment_id, payment_hash, rejected_by_dest, retry, .. } => { + Event::PaymentPathFailed { + payment_id, payment_hash, rejected_by_dest, path, short_channel_id, retry, .. + } => { + if let Some(short_channel_id) = short_channel_id { + self.scorer.lock().payment_path_failed(path, *short_channel_id); + } + let mut payment_cache = self.payment_cache.lock().unwrap(); let entry = loop { let entry = payment_cache.entry(*payment_hash); @@ -354,11 +382,13 @@ mod tests { use lightning::ln::PaymentPreimage; use lightning::ln::features::{ChannelFeatures, NodeFeatures}; use lightning::ln::msgs::{ErrorAction, LightningError}; - use lightning::routing::router::{Route, RouteHop}; + use lightning::routing::network_graph::NodeId; + use lightning::routing::router::{Payee, Route, RouteHop}; use lightning::util::test_utils::TestLogger; use lightning::util::errors::APIError; use lightning::util::events::Event; use secp256k1::{SecretKey, PublicKey, Secp256k1}; + use std::cell::RefCell; use std::time::{SystemTime, Duration}; fn invoice(payment_preimage: PaymentPreimage) -> Invoice { @@ -422,9 +452,10 @@ mod tests { let payer = TestPayer::new(); let router = TestRouter {}; + let scorer = RefCell::new(TestScorer::new()); let logger = TestLogger::new(); let invoice_payer = - InvoicePayer::new(&payer, router, &logger, event_handler, RetryAttempts(0)); + InvoicePayer::new(&payer, router, &scorer, &logger, event_handler, RetryAttempts(0)); let payment_id = Some(invoice_payer.pay_invoice(&invoice).unwrap()); assert_eq!(*payer.attempts.borrow(), 1); @@ -450,9 +481,10 @@ mod tests { .expect_value_msat(final_value_msat) .expect_value_msat(final_value_msat / 2); let router = TestRouter {}; + let scorer = RefCell::new(TestScorer::new()); let logger = TestLogger::new(); let invoice_payer = - InvoicePayer::new(&payer, router, &logger, event_handler, RetryAttempts(2)); + InvoicePayer::new(&payer, router, &scorer, &logger, event_handler, RetryAttempts(2)); let payment_id = Some(invoice_payer.pay_invoice(&invoice).unwrap()); assert_eq!(*payer.attempts.borrow(), 1); @@ -490,9 +522,10 @@ mod tests { let payer = TestPayer::new(); let router = TestRouter {}; + let scorer = RefCell::new(TestScorer::new()); let logger = TestLogger::new(); let invoice_payer = - InvoicePayer::new(&payer, router, &logger, event_handler, RetryAttempts(2)); + InvoicePayer::new(&payer, router, &scorer, &logger, event_handler, RetryAttempts(2)); let payment_id = Some(PaymentId([1; 32])); let event = Event::PaymentPathFailed { @@ -534,9 +567,10 @@ mod tests { .expect_value_msat(final_value_msat / 2) .expect_value_msat(final_value_msat / 2); let router = TestRouter {}; + let scorer = RefCell::new(TestScorer::new()); let logger = TestLogger::new(); let invoice_payer = - InvoicePayer::new(&payer, router, &logger, event_handler, RetryAttempts(2)); + InvoicePayer::new(&payer, router, &scorer, &logger, event_handler, RetryAttempts(2)); let payment_id = Some(invoice_payer.pay_invoice(&invoice).unwrap()); assert_eq!(*payer.attempts.borrow(), 1); @@ -583,9 +617,10 @@ mod tests { let payer = TestPayer::new(); let router = TestRouter {}; + let scorer = RefCell::new(TestScorer::new()); let logger = TestLogger::new(); let invoice_payer = - InvoicePayer::new(&payer, router, &logger, event_handler, RetryAttempts(2)); + InvoicePayer::new(&payer, router, &scorer, &logger, event_handler, RetryAttempts(2)); let payment_preimage = PaymentPreimage([1; 32]); let invoice = invoice(payment_preimage); @@ -614,9 +649,10 @@ mod tests { let payer = TestPayer::new(); let router = TestRouter {}; + let scorer = RefCell::new(TestScorer::new()); let logger = TestLogger::new(); let invoice_payer = - InvoicePayer::new(&payer, router, &logger, event_handler, RetryAttempts(2)); + InvoicePayer::new(&payer, router, &scorer, &logger, event_handler, RetryAttempts(2)); let payment_preimage = PaymentPreimage([1; 32]); let invoice = expired_invoice(payment_preimage); @@ -651,9 +687,10 @@ mod tests { .fails_on_attempt(2) .expect_value_msat(final_value_msat); let router = TestRouter {}; + let scorer = RefCell::new(TestScorer::new()); let logger = TestLogger::new(); let invoice_payer = - InvoicePayer::new(&payer, router, &logger, event_handler, RetryAttempts(2)); + InvoicePayer::new(&payer, router, &scorer, &logger, event_handler, RetryAttempts(2)); let payment_id = Some(invoice_payer.pay_invoice(&invoice).unwrap()); assert_eq!(*payer.attempts.borrow(), 1); @@ -680,9 +717,10 @@ mod tests { let payer = TestPayer::new(); let router = TestRouter {}; + let scorer = RefCell::new(TestScorer::new()); let logger = TestLogger::new(); let invoice_payer = - InvoicePayer::new(&payer, router, &logger, event_handler, RetryAttempts(2)); + InvoicePayer::new(&payer, router, &scorer, &logger, event_handler, RetryAttempts(2)); let payment_preimage = PaymentPreimage([1; 32]); let invoice = invoice(payment_preimage); @@ -711,9 +749,10 @@ mod tests { let payer = TestPayer::new(); let router = TestRouter {}; + let scorer = RefCell::new(TestScorer::new()); let logger = TestLogger::new(); let invoice_payer = - InvoicePayer::new(&payer, router, &logger, event_handler, RetryAttempts(0)); + InvoicePayer::new(&payer, router, &scorer, &logger, event_handler, RetryAttempts(0)); let payment_preimage = PaymentPreimage([1; 32]); let invoice = invoice(payment_preimage); @@ -751,9 +790,10 @@ mod tests { fn fails_paying_invoice_with_routing_errors() { let payer = TestPayer::new(); let router = FailingRouter {}; + let scorer = RefCell::new(TestScorer::new()); let logger = TestLogger::new(); let invoice_payer = - InvoicePayer::new(&payer, router, &logger, |_: &_| {}, RetryAttempts(0)); + InvoicePayer::new(&payer, router, &scorer, &logger, |_: &_| {}, RetryAttempts(0)); let payment_preimage = PaymentPreimage([1; 32]); let invoice = invoice(payment_preimage); @@ -768,9 +808,10 @@ mod tests { fn fails_paying_invoice_with_sending_errors() { let payer = TestPayer::new().fails_on_attempt(1); let router = TestRouter {}; + let scorer = RefCell::new(TestScorer::new()); let logger = TestLogger::new(); let invoice_payer = - InvoicePayer::new(&payer, router, &logger, |_: &_| {}, RetryAttempts(0)); + InvoicePayer::new(&payer, router, &scorer, &logger, |_: &_| {}, RetryAttempts(0)); let payment_preimage = PaymentPreimage([1; 32]); let invoice = invoice(payment_preimage); @@ -793,9 +834,10 @@ mod tests { let payer = TestPayer::new().expect_value_msat(final_value_msat); let router = TestRouter {}; + let scorer = RefCell::new(TestScorer::new()); let logger = TestLogger::new(); let invoice_payer = - InvoicePayer::new(&payer, router, &logger, event_handler, RetryAttempts(0)); + InvoicePayer::new(&payer, router, &scorer, &logger, event_handler, RetryAttempts(0)); let payment_id = Some(invoice_payer.pay_zero_value_invoice(&invoice, final_value_msat).unwrap()); @@ -815,9 +857,10 @@ mod tests { let payer = TestPayer::new(); let router = TestRouter {}; + let scorer = RefCell::new(TestScorer::new()); let logger = TestLogger::new(); let invoice_payer = - InvoicePayer::new(&payer, router, &logger, event_handler, RetryAttempts(0)); + InvoicePayer::new(&payer, router, &scorer, &logger, event_handler, RetryAttempts(0)); let payment_preimage = PaymentPreimage([1; 32]); let invoice = invoice(payment_preimage); @@ -830,6 +873,40 @@ mod tests { } } + #[test] + fn scores_failed_channel() { + let event_handled = core::cell::RefCell::new(false); + let event_handler = |_: &_| { *event_handled.borrow_mut() = true; }; + + let payment_preimage = PaymentPreimage([1; 32]); + let invoice = invoice(payment_preimage); + let payment_hash = PaymentHash(invoice.payment_hash().clone().into_inner()); + let final_value_msat = invoice.amount_milli_satoshis().unwrap(); + let path = TestRouter::path_for_value(final_value_msat); + let short_channel_id = Some(path[0].short_channel_id); + + // Expect that scorer is given short_channel_id upon handling the event. + let payer = TestPayer::new(); + let router = TestRouter {}; + let scorer = RefCell::new(TestScorer::new().expect_channel_failure(short_channel_id.unwrap())); + let logger = TestLogger::new(); + let invoice_payer = + InvoicePayer::new(&payer, router, &scorer, &logger, event_handler, RetryAttempts(2)); + + let payment_id = Some(invoice_payer.pay_invoice(&invoice).unwrap()); + let event = Event::PaymentPathFailed { + payment_id, + payment_hash, + network_update: None, + rejected_by_dest: false, + all_paths_failed: false, + path, + short_channel_id, + retry: Some(TestRouter::retry_for_invoice(&invoice)), + }; + invoice_payer.handle_event(&event); + } + struct TestRouter; impl TestRouter { @@ -873,12 +950,13 @@ mod tests { } } - impl Router for TestRouter { + impl Router for TestRouter { fn find_route( &self, _payer: &PublicKey, params: &RouteParameters, _first_hops: Option<&[&ChannelDetails]>, + _scorer: &S, ) -> Result { Ok(Route { payee: Some(params.payee.clone()), ..Self::route_for_value(params.final_value_msat) @@ -888,17 +966,59 @@ mod tests { struct FailingRouter; - impl Router for FailingRouter { + impl Router for FailingRouter { fn find_route( &self, _payer: &PublicKey, _params: &RouteParameters, _first_hops: Option<&[&ChannelDetails]>, + _scorer: &S, ) -> Result { Err(LightningError { err: String::new(), action: ErrorAction::IgnoreError }) } } + struct TestScorer { + expectations: std::collections::VecDeque, + } + + impl TestScorer { + fn new() -> Self { + Self { + expectations: std::collections::VecDeque::new(), + } + } + + fn expect_channel_failure(mut self, short_channel_id: u64) -> Self { + self.expectations.push_back(short_channel_id); + self + } + } + + impl routing::Score for TestScorer { + fn channel_penalty_msat( + &self, _short_channel_id: u64, _source: &NodeId, _target: &NodeId + ) -> u64 { 0 } + + fn payment_path_failed(&mut self, _path: &Vec, short_channel_id: u64) { + if let Some(expected_short_channel_id) = self.expectations.pop_front() { + assert_eq!(short_channel_id, expected_short_channel_id); + } + } + } + + impl Drop for TestScorer { + fn drop(&mut self) { + if std::thread::panicking() { + return; + } + + if !self.expectations.is_empty() { + panic!("Unsatisfied channel failure expectations: {:?}", self.expectations); + } + } + } + struct TestPayer { expectations: core::cell::RefCell>, attempts: core::cell::RefCell, diff --git a/lightning-invoice/src/utils.rs b/lightning-invoice/src/utils.rs index ef885f20381..8da9994a3f7 100644 --- a/lightning-invoice/src/utils.rs +++ b/lightning-invoice/src/utils.rs @@ -11,9 +11,9 @@ use lightning::chain::keysinterface::{Sign, KeysInterface}; use lightning::ln::{PaymentHash, PaymentSecret}; use lightning::ln::channelmanager::{ChannelDetails, ChannelManager, PaymentId, PaymentSendFailure, MIN_FINAL_CLTV_EXPIRY}; use lightning::ln::msgs::LightningError; +use lightning::routing; use lightning::routing::network_graph::{NetworkGraph, RoutingFees}; use lightning::routing::router::{Route, RouteHint, RouteHintHop, RouteParameters, find_route}; -use lightning::routing::scorer::Scorer; use lightning::util::logger::Logger; use secp256k1::key::PublicKey; use std::convert::TryInto; @@ -109,13 +109,13 @@ impl DefaultRouter where G: Deref, L:: } } -impl Router for DefaultRouter +impl Router for DefaultRouter where G: Deref, L::Target: Logger { fn find_route( &self, payer: &PublicKey, params: &RouteParameters, first_hops: Option<&[&ChannelDetails]>, + scorer: &S ) -> Result { - let scorer = Scorer::default(); - find_route(payer, params, &*self.network_graph, first_hops, &*self.logger, &scorer) + find_route(payer, params, &*self.network_graph, first_hops, &*self.logger, scorer) } } @@ -183,7 +183,7 @@ mod test { let first_hops = nodes[0].node.list_usable_channels(); let network_graph = &nodes[0].net_graph_msg_handler.network_graph; let logger = test_utils::TestLogger::new(); - let scorer = Scorer::new(0); + let scorer = Scorer::with_fixed_penalty(0); let route = find_route( &nodes[0].node.get_our_node_id(), ¶ms, network_graph, Some(&first_hops.iter().collect::>()), &logger, &scorer, diff --git a/lightning/src/ln/channelmanager.rs b/lightning/src/ln/channelmanager.rs index 90f1a2a91b2..e6c27efc8ec 100644 --- a/lightning/src/ln/channelmanager.rs +++ b/lightning/src/ln/channelmanager.rs @@ -6279,7 +6279,7 @@ mod tests { let node_chanmgrs = create_node_chanmgrs(2, &node_cfgs, &[None, None]); let nodes = create_network(2, &node_cfgs, &node_chanmgrs); create_announced_chan_between_nodes(&nodes, 0, 1, InitFeatures::known(), InitFeatures::known()); - let scorer = Scorer::new(0); + let scorer = Scorer::with_fixed_penalty(0); // To start (1), send a regular payment but don't claim it. let expected_route = [&nodes[1]]; @@ -6384,7 +6384,7 @@ mod tests { }; let network_graph = &nodes[0].net_graph_msg_handler.network_graph; let first_hops = nodes[0].node.list_usable_channels(); - let scorer = Scorer::new(0); + let scorer = Scorer::with_fixed_penalty(0); let route = find_route( &payer_pubkey, ¶ms, network_graph, Some(&first_hops.iter().collect::>()), nodes[0].logger, &scorer @@ -6427,7 +6427,7 @@ mod tests { }; let network_graph = &nodes[0].net_graph_msg_handler.network_graph; let first_hops = nodes[0].node.list_usable_channels(); - let scorer = Scorer::new(0); + let scorer = Scorer::with_fixed_penalty(0); let route = find_route( &payer_pubkey, ¶ms, network_graph, Some(&first_hops.iter().collect::>()), nodes[0].logger, &scorer @@ -6602,7 +6602,7 @@ pub mod bench { let usable_channels = $node_a.list_usable_channels(); let payee = Payee::new($node_b.get_our_node_id()) .with_features(InvoiceFeatures::known()); - let scorer = Scorer::new(0); + let scorer = Scorer::with_fixed_penalty(0); let route = get_route(&$node_a.get_our_node_id(), &payee, &dummy_graph, Some(&usable_channels.iter().map(|r| r).collect::>()), 10_000, TEST_FINAL_CLTV, &logger_a, &scorer).unwrap(); diff --git a/lightning/src/ln/functional_test_utils.rs b/lightning/src/ln/functional_test_utils.rs index 014cf3b7ddb..03fa562555e 100644 --- a/lightning/src/ln/functional_test_utils.rs +++ b/lightning/src/ln/functional_test_utils.rs @@ -1015,7 +1015,7 @@ macro_rules! get_route_and_payment_hash { .with_features($crate::ln::features::InvoiceFeatures::known()) .with_route_hints($last_hops); let net_graph_msg_handler = &$send_node.net_graph_msg_handler; - let scorer = ::routing::scorer::Scorer::new(0); + let scorer = ::routing::scorer::Scorer::with_fixed_penalty(0); let route = ::routing::router::get_route( &$send_node.node.get_our_node_id(), &payee, &net_graph_msg_handler.network_graph, Some(&$send_node.node.list_usable_channels().iter().collect::>()), @@ -1339,7 +1339,7 @@ pub fn route_payment<'a, 'b, 'c>(origin_node: &Node<'a, 'b, 'c>, expected_route: let payee = Payee::new(expected_route.last().unwrap().node.get_our_node_id()) .with_features(InvoiceFeatures::known()); let net_graph_msg_handler = &origin_node.net_graph_msg_handler; - let scorer = Scorer::new(0); + let scorer = Scorer::with_fixed_penalty(0); let route = get_route( &origin_node.node.get_our_node_id(), &payee, &net_graph_msg_handler.network_graph, Some(&origin_node.node.list_usable_channels().iter().collect::>()), @@ -1358,7 +1358,7 @@ pub fn route_over_limit<'a, 'b, 'c>(origin_node: &Node<'a, 'b, 'c>, expected_rou let payee = Payee::new(expected_route.last().unwrap().node.get_our_node_id()) .with_features(InvoiceFeatures::known()); let net_graph_msg_handler = &origin_node.net_graph_msg_handler; - let scorer = Scorer::new(0); + let scorer = Scorer::with_fixed_penalty(0); let route = get_route(&origin_node.node.get_our_node_id(), &payee, &net_graph_msg_handler.network_graph, None, recv_value, TEST_FINAL_CLTV, origin_node.logger, &scorer).unwrap(); assert_eq!(route.paths.len(), 1); assert_eq!(route.paths[0].len(), expected_route.len()); diff --git a/lightning/src/ln/functional_tests.rs b/lightning/src/ln/functional_tests.rs index 564a9e5a090..76190b671ff 100644 --- a/lightning/src/ln/functional_tests.rs +++ b/lightning/src/ln/functional_tests.rs @@ -7161,7 +7161,7 @@ fn test_check_htlc_underpaying() { // Create some initial channels create_announced_chan_between_nodes(&nodes, 0, 1, InitFeatures::known(), InitFeatures::known()); - let scorer = Scorer::new(0); + let scorer = Scorer::with_fixed_penalty(0); let payee = Payee::new(nodes[1].node.get_our_node_id()).with_features(InvoiceFeatures::known()); let route = get_route(&nodes[0].node.get_our_node_id(), &payee, &nodes[0].net_graph_msg_handler.network_graph, None, 10_000, TEST_FINAL_CLTV, nodes[0].logger, &scorer).unwrap(); let (_, our_payment_hash, _) = get_payment_preimage_hash!(nodes[0]); @@ -7561,7 +7561,7 @@ fn test_bump_penalty_txn_on_revoked_htlcs() { let chan = create_announced_chan_between_nodes_with_value(&nodes, 0, 1, 1000000, 59000000, InitFeatures::known(), InitFeatures::known()); // Lock HTLC in both directions (using a slightly lower CLTV delay to provide timely RBF bumps) let payee = Payee::new(nodes[1].node.get_our_node_id()).with_features(InvoiceFeatures::known()); - let scorer = Scorer::new(0); + let scorer = Scorer::with_fixed_penalty(0); let route = get_route(&nodes[0].node.get_our_node_id(), &payee, &nodes[0].net_graph_msg_handler.network_graph, None, 3_000_000, 50, nodes[0].logger, &scorer).unwrap(); let payment_preimage = send_along_route(&nodes[0], route, &[&nodes[1]], 3_000_000).0; @@ -9061,7 +9061,7 @@ fn test_keysend_payments_to_public_node() { final_value_msat: 10000, final_cltv_expiry_delta: 40, }; - let scorer = Scorer::new(0); + let scorer = Scorer::with_fixed_penalty(0); let route = find_route(&payer_pubkey, ¶ms, &network_graph, None, nodes[0].logger, &scorer).unwrap(); let test_preimage = PaymentPreimage([42; 32]); @@ -9095,7 +9095,7 @@ fn test_keysend_payments_to_private_node() { }; let network_graph = &nodes[0].net_graph_msg_handler.network_graph; let first_hops = nodes[0].node.list_usable_channels(); - let scorer = Scorer::new(0); + let scorer = Scorer::with_fixed_penalty(0); let route = find_route( &payer_pubkey, ¶ms, &network_graph, Some(&first_hops.iter().collect::>()), nodes[0].logger, &scorer diff --git a/lightning/src/ln/shutdown_tests.rs b/lightning/src/ln/shutdown_tests.rs index 388104bf023..06833ef4509 100644 --- a/lightning/src/ln/shutdown_tests.rs +++ b/lightning/src/ln/shutdown_tests.rs @@ -82,7 +82,7 @@ fn updates_shutdown_wait() { let chan_1 = create_announced_chan_between_nodes(&nodes, 0, 1, InitFeatures::known(), InitFeatures::known()); let chan_2 = create_announced_chan_between_nodes(&nodes, 1, 2, InitFeatures::known(), InitFeatures::known()); let logger = test_utils::TestLogger::new(); - let scorer = Scorer::new(0); + let scorer = Scorer::with_fixed_penalty(0); let (our_payment_preimage, our_payment_hash, _) = route_payment(&nodes[0], &[&nodes[1], &nodes[2]], 100000); diff --git a/lightning/src/routing/mod.rs b/lightning/src/routing/mod.rs index 51ffd91b504..d6c016468ce 100644 --- a/lightning/src/routing/mod.rs +++ b/lightning/src/routing/mod.rs @@ -14,6 +14,12 @@ pub mod router; pub mod scorer; use routing::network_graph::NodeId; +use routing::router::RouteHop; + +use prelude::*; +use core::cell::{RefCell, RefMut}; +use core::ops::DerefMut; +use sync::{Mutex, MutexGuard}; /// An interface used to score payment channels for path finding. /// @@ -22,4 +28,49 @@ pub trait Score { /// Returns the fee in msats willing to be paid to avoid routing through the given channel /// in the direction from `source` to `target`. fn channel_penalty_msat(&self, short_channel_id: u64, source: &NodeId, target: &NodeId) -> u64; + + /// Handles updating channel penalties after failing to route through a channel. + fn payment_path_failed(&mut self, path: &Vec, short_channel_id: u64); +} + +/// A scorer that is accessed under a lock. +/// +/// Needed so that calls to [`Score::channel_penalty_msat`] in [`find_route`] can be made while +/// having shared ownership of a scorer but without requiring internal locking in [`Score`] +/// implementations. Internal locking would be detrimental to route finding performance and could +/// result in [`Score::channel_penalty_msat`] returning a different value for the same channel. +/// +/// [`find_route`]: crate::routing::router::find_route +pub trait LockableScore<'a> { + /// The locked [`Score`] type. + type Locked: 'a + Score; + + /// Returns the locked scorer. + fn lock(&'a self) -> Self::Locked; +} + +impl<'a, T: 'a + Score> LockableScore<'a> for Mutex { + type Locked = MutexGuard<'a, T>; + + fn lock(&'a self) -> MutexGuard<'a, T> { + Mutex::lock(self).unwrap() + } +} + +impl<'a, T: 'a + Score> LockableScore<'a> for RefCell { + type Locked = RefMut<'a, T>; + + fn lock(&'a self) -> RefMut<'a, T> { + self.borrow_mut() + } +} + +impl> Score for T { + fn channel_penalty_msat(&self, short_channel_id: u64, source: &NodeId, target: &NodeId) -> u64 { + self.deref().channel_penalty_msat(short_channel_id, source, target) + } + + fn payment_path_failed(&mut self, path: &Vec, short_channel_id: u64) { + self.deref_mut().payment_path_failed(path, short_channel_id) + } } diff --git a/lightning/src/routing/router.rs b/lightning/src/routing/router.rs index 545ff9f24ce..93128c02401 100644 --- a/lightning/src/routing/router.rs +++ b/lightning/src/routing/router.rs @@ -1928,7 +1928,7 @@ mod tests { let (secp_ctx, net_graph_msg_handler, _, logger) = build_graph(); let (_, our_id, _, nodes) = get_nodes(&secp_ctx); let payee = Payee::new(nodes[2]); - let scorer = Scorer::new(0); + let scorer = Scorer::with_fixed_penalty(0); // Simple route to 2 via 1 @@ -1959,7 +1959,7 @@ mod tests { let (secp_ctx, net_graph_msg_handler, _, logger) = build_graph(); let (_, our_id, _, nodes) = get_nodes(&secp_ctx); let payee = Payee::new(nodes[2]); - let scorer = Scorer::new(0); + let scorer = Scorer::with_fixed_penalty(0); // Simple route to 2 via 1 @@ -1978,7 +1978,7 @@ mod tests { let (secp_ctx, net_graph_msg_handler, _, logger) = build_graph(); let (our_privkey, our_id, privkeys, nodes) = get_nodes(&secp_ctx); let payee = Payee::new(nodes[2]); - let scorer = Scorer::new(0); + let scorer = Scorer::with_fixed_penalty(0); // Simple route to 2 via 1 @@ -2103,7 +2103,7 @@ mod tests { let (secp_ctx, net_graph_msg_handler, _, logger) = build_graph(); let (our_privkey, our_id, privkeys, nodes) = get_nodes(&secp_ctx); let payee = Payee::new(nodes[2]).with_features(InvoiceFeatures::known()); - let scorer = Scorer::new(0); + let scorer = Scorer::with_fixed_penalty(0); // A route to node#2 via two paths. // One path allows transferring 35-40 sats, another one also allows 35-40 sats. @@ -2239,7 +2239,7 @@ mod tests { let (secp_ctx, net_graph_msg_handler, _, logger) = build_graph(); let (our_privkey, our_id, privkeys, nodes) = get_nodes(&secp_ctx); let payee = Payee::new(nodes[2]); - let scorer = Scorer::new(0); + let scorer = Scorer::with_fixed_penalty(0); // // Disable channels 4 and 12 by flags=2 update_channel(&net_graph_msg_handler, &secp_ctx, &privkeys[1], UnsignedChannelUpdate { @@ -2297,7 +2297,7 @@ mod tests { let (secp_ctx, net_graph_msg_handler, _, logger) = build_graph(); let (_, our_id, privkeys, nodes) = get_nodes(&secp_ctx); let payee = Payee::new(nodes[2]); - let scorer = Scorer::new(0); + let scorer = Scorer::with_fixed_penalty(0); // Disable nodes 1, 2, and 8 by requiring unknown feature bits let unknown_features = NodeFeatures::known().set_unknown_feature_required(); @@ -2338,7 +2338,7 @@ mod tests { fn our_chans_test() { let (secp_ctx, net_graph_msg_handler, _, logger) = build_graph(); let (_, our_id, _, nodes) = get_nodes(&secp_ctx); - let scorer = Scorer::new(0); + let scorer = Scorer::with_fixed_penalty(0); // Route to 1 via 2 and 3 because our channel to 1 is disabled let payee = Payee::new(nodes[0]); @@ -2467,7 +2467,7 @@ mod tests { fn partial_route_hint_test() { let (secp_ctx, net_graph_msg_handler, _, logger) = build_graph(); let (_, our_id, _, nodes) = get_nodes(&secp_ctx); - let scorer = Scorer::new(0); + let scorer = Scorer::with_fixed_penalty(0); // Simple test across 2, 3, 5, and 4 via a last_hop channel // Tests the behaviour when the RouteHint contains a suboptimal hop. @@ -2566,7 +2566,7 @@ mod tests { let (secp_ctx, net_graph_msg_handler, _, logger) = build_graph(); let (_, our_id, _, nodes) = get_nodes(&secp_ctx); let payee = Payee::new(nodes[6]).with_route_hints(empty_last_hop(&nodes)); - let scorer = Scorer::new(0); + let scorer = Scorer::with_fixed_penalty(0); // Test handling of an empty RouteHint passed in Invoice. @@ -2648,7 +2648,7 @@ mod tests { let (secp_ctx, net_graph_msg_handler, _, logger) = build_graph(); let (_, our_id, privkeys, nodes) = get_nodes(&secp_ctx); let payee = Payee::new(nodes[6]).with_route_hints(multi_hint_last_hops(&nodes)); - let scorer = Scorer::new(0); + let scorer = Scorer::with_fixed_penalty(0); // Test through channels 2, 3, 5, 8. // Test shows that multiple hop hints are considered. @@ -2754,7 +2754,7 @@ mod tests { let (secp_ctx, net_graph_msg_handler, _, logger) = build_graph(); let (_, our_id, _, nodes) = get_nodes(&secp_ctx); let payee = Payee::new(nodes[6]).with_route_hints(last_hops_with_public_channel(&nodes)); - let scorer = Scorer::new(0); + let scorer = Scorer::with_fixed_penalty(0); // This test shows that public routes can be present in the invoice // which would be handled in the same manner. @@ -2803,7 +2803,7 @@ mod tests { fn our_chans_last_hop_connect_test() { let (secp_ctx, net_graph_msg_handler, _, logger) = build_graph(); let (_, our_id, _, nodes) = get_nodes(&secp_ctx); - let scorer = Scorer::new(0); + let scorer = Scorer::with_fixed_penalty(0); // Simple test with outbound channel to 4 to test that last_hops and first_hops connect let our_chans = vec![get_channel_details(Some(42), nodes[3].clone(), InitFeatures::from_le_bytes(vec![0b11]), 250_000_000)]; @@ -2924,7 +2924,7 @@ mod tests { }]); let payee = Payee::new(target_node_id).with_route_hints(vec![last_hops]); let our_chans = vec![get_channel_details(Some(42), middle_node_id, InitFeatures::from_le_bytes(vec![0b11]), outbound_capacity_msat)]; - let scorer = Scorer::new(0); + let scorer = Scorer::with_fixed_penalty(0); get_route(&source_node_id, &payee, &NetworkGraph::new(genesis_block(Network::Testnet).header.block_hash()), Some(&our_chans.iter().collect::>()), route_val, 42, &test_utils::TestLogger::new(), &scorer) } @@ -2978,7 +2978,7 @@ mod tests { let (secp_ctx, mut net_graph_msg_handler, chain_monitor, logger) = build_graph(); let (our_privkey, our_id, privkeys, nodes) = get_nodes(&secp_ctx); - let scorer = Scorer::new(0); + let scorer = Scorer::with_fixed_penalty(0); let payee = Payee::new(nodes[2]).with_features(InvoiceFeatures::known()); // We will use a simple single-path route from @@ -3250,7 +3250,7 @@ mod tests { // one of the latter hops is limited. let (secp_ctx, net_graph_msg_handler, _, logger) = build_graph(); let (our_privkey, our_id, privkeys, nodes) = get_nodes(&secp_ctx); - let scorer = Scorer::new(0); + let scorer = Scorer::with_fixed_penalty(0); let payee = Payee::new(nodes[3]).with_features(InvoiceFeatures::known()); // Path via {node7, node2, node4} is channels {12, 13, 6, 11}. @@ -3373,7 +3373,7 @@ mod tests { fn ignore_fee_first_hop_test() { let (secp_ctx, net_graph_msg_handler, _, logger) = build_graph(); let (our_privkey, our_id, privkeys, nodes) = get_nodes(&secp_ctx); - let scorer = Scorer::new(0); + let scorer = Scorer::with_fixed_penalty(0); let payee = Payee::new(nodes[2]); // Path via node0 is channels {1, 3}. Limit them to 100 and 50 sats (total limit 50). @@ -3419,7 +3419,7 @@ mod tests { fn simple_mpp_route_test() { let (secp_ctx, net_graph_msg_handler, _, logger) = build_graph(); let (our_privkey, our_id, privkeys, nodes) = get_nodes(&secp_ctx); - let scorer = Scorer::new(0); + let scorer = Scorer::with_fixed_penalty(0); let payee = Payee::new(nodes[2]).with_features(InvoiceFeatures::known()); // We need a route consisting of 3 paths: @@ -3550,7 +3550,7 @@ mod tests { fn long_mpp_route_test() { let (secp_ctx, net_graph_msg_handler, _, logger) = build_graph(); let (our_privkey, our_id, privkeys, nodes) = get_nodes(&secp_ctx); - let scorer = Scorer::new(0); + let scorer = Scorer::with_fixed_penalty(0); let payee = Payee::new(nodes[3]).with_features(InvoiceFeatures::known()); // We need a route consisting of 3 paths: @@ -3712,7 +3712,7 @@ mod tests { fn mpp_cheaper_route_test() { let (secp_ctx, net_graph_msg_handler, _, logger) = build_graph(); let (our_privkey, our_id, privkeys, nodes) = get_nodes(&secp_ctx); - let scorer = Scorer::new(0); + let scorer = Scorer::with_fixed_penalty(0); let payee = Payee::new(nodes[3]).with_features(InvoiceFeatures::known()); // This test checks that if we have two cheaper paths and one more expensive path, @@ -3879,7 +3879,7 @@ mod tests { // if the fee is not properly accounted for, the behavior is different. let (secp_ctx, net_graph_msg_handler, _, logger) = build_graph(); let (our_privkey, our_id, privkeys, nodes) = get_nodes(&secp_ctx); - let scorer = Scorer::new(0); + let scorer = Scorer::with_fixed_penalty(0); let payee = Payee::new(nodes[3]).with_features(InvoiceFeatures::known()); // We need a route consisting of 2 paths: @@ -4048,7 +4048,7 @@ mod tests { // path finding we realize that we found more capacity than we need. let (secp_ctx, net_graph_msg_handler, _, logger) = build_graph(); let (our_privkey, our_id, privkeys, nodes) = get_nodes(&secp_ctx); - let scorer = Scorer::new(0); + let scorer = Scorer::with_fixed_penalty(0); let payee = Payee::new(nodes[2]).with_features(InvoiceFeatures::known()); // We need a route consisting of 3 paths: @@ -4205,7 +4205,7 @@ mod tests { let network_graph = NetworkGraph::new(genesis_block(Network::Testnet).header.block_hash()); let net_graph_msg_handler = NetGraphMsgHandler::new(network_graph, None, Arc::clone(&logger)); let (our_privkey, our_id, privkeys, nodes) = get_nodes(&secp_ctx); - let scorer = Scorer::new(0); + let scorer = Scorer::with_fixed_penalty(0); let payee = Payee::new(nodes[6]); add_channel(&net_graph_msg_handler, &secp_ctx, &our_privkey, &privkeys[1], ChannelFeatures::from_le_bytes(id_to_feature_flags(6)), 6); @@ -4334,7 +4334,7 @@ mod tests { // we calculated fees on a higher value, resulting in us ignoring such paths. let (secp_ctx, net_graph_msg_handler, _, logger) = build_graph(); let (our_privkey, our_id, _, nodes) = get_nodes(&secp_ctx); - let scorer = Scorer::new(0); + let scorer = Scorer::with_fixed_penalty(0); let payee = Payee::new(nodes[2]); // We modify the graph to set the htlc_maximum of channel 2 to below the value we wish to @@ -4396,7 +4396,7 @@ mod tests { // resulting in us thinking there is no possible path, even if other paths exist. let (secp_ctx, net_graph_msg_handler, _, logger) = build_graph(); let (our_privkey, our_id, privkeys, nodes) = get_nodes(&secp_ctx); - let scorer = Scorer::new(0); + let scorer = Scorer::with_fixed_penalty(0); let payee = Payee::new(nodes[2]).with_features(InvoiceFeatures::known()); // We modify the graph to set the htlc_minimum of channel 2 and 4 as needed - channel 2 @@ -4463,7 +4463,7 @@ mod tests { let (_, our_id, _, nodes) = get_nodes(&secp_ctx); let logger = Arc::new(test_utils::TestLogger::new()); let network_graph = NetworkGraph::new(genesis_block(Network::Testnet).header.block_hash()); - let scorer = Scorer::new(0); + let scorer = Scorer::with_fixed_penalty(0); let payee = Payee::new(nodes[0]).with_features(InvoiceFeatures::known()); { @@ -4504,7 +4504,7 @@ mod tests { let payee = Payee::new(nodes[6]).with_route_hints(last_hops(&nodes)); // Without penalizing each hop 100 msats, a longer path with lower fees is chosen. - let scorer = Scorer::new(0); + let scorer = Scorer::with_fixed_penalty(0); let route = get_route( &our_id, &payee, &net_graph_msg_handler.network_graph, None, 100, 42, Arc::clone(&logger), &scorer @@ -4517,7 +4517,7 @@ mod tests { // Applying a 100 msat penalty to each hop results in taking channels 7 and 10 to nodes[6] // from nodes[2] rather than channel 6, 11, and 8, even though the longer path is cheaper. - let scorer = Scorer::new(100); + let scorer = Scorer::with_fixed_penalty(100); let route = get_route( &our_id, &payee, &net_graph_msg_handler.network_graph, None, 100, 42, Arc::clone(&logger), &scorer @@ -4537,6 +4537,8 @@ mod tests { fn channel_penalty_msat(&self, short_channel_id: u64, _source: &NodeId, _target: &NodeId) -> u64 { if short_channel_id == self.short_channel_id { u64::max_value() } else { 0 } } + + fn payment_path_failed(&mut self, _path: &Vec, _short_channel_id: u64) {} } struct BadNodeScorer { @@ -4547,6 +4549,8 @@ mod tests { fn channel_penalty_msat(&self, _short_channel_id: u64, _source: &NodeId, target: &NodeId) -> u64 { if *target == self.node_id { u64::max_value() } else { 0 } } + + fn payment_path_failed(&mut self, _path: &Vec, _short_channel_id: u64) {} } #[test] @@ -4556,7 +4560,7 @@ mod tests { let payee = Payee::new(nodes[6]).with_route_hints(last_hops(&nodes)); // A path to nodes[6] exists when no penalties are applied to any channel. - let scorer = Scorer::new(0); + let scorer = Scorer::with_fixed_penalty(0); let route = get_route( &our_id, &payee, &net_graph_msg_handler.network_graph, None, 100, 42, Arc::clone(&logger), &scorer @@ -4685,7 +4689,7 @@ mod tests { }, }; let graph = NetworkGraph::read(&mut d).unwrap(); - let scorer = Scorer::new(0); + let scorer = Scorer::with_fixed_penalty(0); // First, get 100 (source, destination) pairs for which route-getting actually succeeds... let mut seed = random_init_seed() as usize; @@ -4716,7 +4720,7 @@ mod tests { }, }; let graph = NetworkGraph::read(&mut d).unwrap(); - let scorer = Scorer::new(0); + let scorer = Scorer::with_fixed_penalty(0); // First, get 100 (source, destination) pairs for which route-getting actually succeeds... let mut seed = random_init_seed() as usize; @@ -4782,7 +4786,7 @@ mod benches { let mut d = test_utils::get_route_file().unwrap(); let graph = NetworkGraph::read(&mut d).unwrap(); let nodes = graph.read_only().nodes().clone(); - let scorer = Scorer::new(0); + let scorer = Scorer::with_fixed_penalty(0); // First, get 100 (source, destination) pairs for which route-getting actually succeeds... let mut path_endpoints = Vec::new(); @@ -4817,7 +4821,7 @@ mod benches { let mut d = test_utils::get_route_file().unwrap(); let graph = NetworkGraph::read(&mut d).unwrap(); let nodes = graph.read_only().nodes().clone(); - let scorer = Scorer::new(0); + let scorer = Scorer::with_fixed_penalty(0); // First, get 100 (source, destination) pairs for which route-getting actually succeeds... let mut path_endpoints = Vec::new(); diff --git a/lightning/src/routing/scorer.rs b/lightning/src/routing/scorer.rs index e3f5c8679b6..d2b167675e0 100644 --- a/lightning/src/routing/scorer.rs +++ b/lightning/src/routing/scorer.rs @@ -19,7 +19,7 @@ //! # //! # use lightning::routing::network_graph::NetworkGraph; //! # use lightning::routing::router::{RouteParameters, find_route}; -//! # use lightning::routing::scorer::Scorer; +//! # use lightning::routing::scorer::{Scorer, ScoringParameters}; //! # use lightning::util::logger::{Logger, Record}; //! # use secp256k1::key::PublicKey; //! # @@ -30,11 +30,15 @@ //! # fn find_scored_route(payer: PublicKey, params: RouteParameters, network_graph: NetworkGraph) { //! # let logger = FakeLogger {}; //! # -//! // Use the default channel penalty. +//! // Use the default channel penalties. //! let scorer = Scorer::default(); //! -//! // Or use a custom channel penalty. -//! let scorer = Scorer::new(1_000); +//! // Or use custom channel penalties. +//! let scorer = Scorer::new(ScoringParameters { +//! base_penalty_msat: 1000, +//! failure_penalty_msat: 2 * 1024 * 1000, +//! ..ScoringParameters::default() +//! }); //! //! let route = find_route(&payer, ¶ms, &network_graph, None, &logger, &scorer); //! # } @@ -45,37 +49,133 @@ use routing; use routing::network_graph::NodeId; +use routing::router::RouteHop; + +use prelude::*; +#[cfg(not(feature = "no-std"))] +use core::time::Duration; +#[cfg(not(feature = "no-std"))] +use std::time::Instant; /// [`routing::Score`] implementation that provides reasonable default behavior. /// /// Used to apply a fixed penalty to each channel, thus avoiding long paths when shorter paths with -/// slightly higher fees are available. +/// slightly higher fees are available. May also further penalize failed channels. /// /// See [module-level documentation] for usage. /// /// [module-level documentation]: crate::routing::scorer pub struct Scorer { - base_penalty_msat: u64, + params: ScoringParameters, + #[cfg(not(feature = "no-std"))] + channel_failures: HashMap, + #[cfg(feature = "no-std")] + channel_failures: HashMap, +} + +/// Parameters for configuring [`Scorer`]. +pub struct ScoringParameters { + /// A fixed penalty in msats to apply to each channel. + pub base_penalty_msat: u64, + + /// A penalty in msats to apply to a channel upon failure. + /// + /// This may be reduced over time based on [`failure_penalty_half_life`]. + /// + /// [`failure_penalty_half_life`]: Self::failure_penalty_half_life + pub failure_penalty_msat: u64, + + /// The time needed before any accumulated channel failure penalties are cut in half. + #[cfg(not(feature = "no-std"))] + pub failure_penalty_half_life: Duration, } impl Scorer { - /// Creates a new scorer using `base_penalty_msat` as the channel penalty. - pub fn new(base_penalty_msat: u64) -> Self { - Self { base_penalty_msat } + /// Creates a new scorer using the given scoring parameters. + pub fn new(params: ScoringParameters) -> Self { + Self { + params, + channel_failures: HashMap::new(), + } + } + + /// Creates a new scorer using `penalty_msat` as a fixed channel penalty. + #[cfg(any(test, feature = "fuzztarget", feature = "_test_utils"))] + pub fn with_fixed_penalty(penalty_msat: u64) -> Self { + Self::new(ScoringParameters { + base_penalty_msat: penalty_msat, + failure_penalty_msat: 0, + #[cfg(not(feature = "no-std"))] + failure_penalty_half_life: Duration::from_secs(0), + }) + } + + #[cfg(not(feature = "no-std"))] + fn decay_from(&self, penalty_msat: u64, last_failure: &Instant) -> u64 { + decay_from(penalty_msat, last_failure, self.params.failure_penalty_half_life) } } impl Default for Scorer { - /// Creates a new scorer using 500 msat as the channel penalty. fn default() -> Self { - Scorer::new(500) + Scorer::new(ScoringParameters::default()) + } +} + +impl Default for ScoringParameters { + fn default() -> Self { + Self { + base_penalty_msat: 500, + failure_penalty_msat: 1024 * 1000, + #[cfg(not(feature = "no-std"))] + failure_penalty_half_life: Duration::from_secs(3600), + } } } impl routing::Score for Scorer { fn channel_penalty_msat( - &self, _short_channel_id: u64, _source: &NodeId, _target: &NodeId + &self, short_channel_id: u64, _source: &NodeId, _target: &NodeId ) -> u64 { - self.base_penalty_msat + #[cfg(not(feature = "no-std"))] + let failure_penalty_msat = match self.channel_failures.get(&short_channel_id) { + Some((penalty_msat, last_failure)) => self.decay_from(*penalty_msat, last_failure), + None => 0, + }; + #[cfg(feature = "no-std")] + let failure_penalty_msat = + self.channel_failures.get(&short_channel_id).copied().unwrap_or(0); + + self.params.base_penalty_msat + failure_penalty_msat + } + + fn payment_path_failed(&mut self, _path: &Vec, short_channel_id: u64) { + let failure_penalty_msat = self.params.failure_penalty_msat; + #[cfg(not(feature = "no-std"))] + { + let half_life = self.params.failure_penalty_half_life; + self.channel_failures + .entry(short_channel_id) + .and_modify(|(penalty_msat, last_failure)| { + let decayed_penalty = decay_from(*penalty_msat, last_failure, half_life); + *penalty_msat = decayed_penalty + failure_penalty_msat; + *last_failure = Instant::now(); + }) + .or_insert_with(|| (failure_penalty_msat, Instant::now())); + } + #[cfg(feature = "no-std")] + self.channel_failures + .entry(short_channel_id) + .and_modify(|penalty_msat| *penalty_msat += failure_penalty_msat) + .or_insert(failure_penalty_msat); + } +} + +#[cfg(not(feature = "no-std"))] +fn decay_from(penalty_msat: u64, last_failure: &Instant, half_life: Duration) -> u64 { + let decays = last_failure.elapsed().as_secs().checked_div(half_life.as_secs()); + match decays { + Some(decays) => penalty_msat >> decays, + None => 0, } }