Skip to content

Commit 1e61b38

Browse files
committed
f - Replace RwLock<S> with S: LockableScore
1 parent 5b3a33a commit 1e61b38

File tree

2 files changed

+77
-58
lines changed

2 files changed

+77
-58
lines changed

lightning-invoice/src/payment.rs

Lines changed: 41 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@
3939
//! # use lightning_invoice::payment::{InvoicePayer, Payer, RetryAttempts, Router};
4040
//! # use secp256k1::key::PublicKey;
4141
//! # use std::ops::Deref;
42+
//! # use std::sync::Mutex;
4243
//! #
4344
//! # struct FakeEventProvider {}
4445
//! # impl EventsProvider for FakeEventProvider {
@@ -88,9 +89,9 @@
8889
//! };
8990
//! # let payer = FakePayer {};
9091
//! # let router = FakeRouter {};
91-
//! # let scorer = FakeScorer {};
92+
//! # let scorer = Mutex::new(FakeScorer {});
9293
//! # let logger = FakeLogger {};
93-
//! let invoice_payer = InvoicePayer::new(&payer, router, scorer, &logger, event_handler, RetryAttempts(2));
94+
//! let invoice_payer = InvoicePayer::new(&payer, router, &scorer, &logger, event_handler, RetryAttempts(2));
9495
//!
9596
//! let invoice = "...";
9697
//! let invoice = invoice.parse::<Invoice>().unwrap();
@@ -117,6 +118,7 @@ use lightning::ln::{PaymentHash, PaymentSecret};
117118
use lightning::ln::channelmanager::{ChannelDetails, PaymentId, PaymentSendFailure};
118119
use lightning::ln::msgs::LightningError;
119120
use lightning::routing;
121+
use lightning::routing::Score;
120122
use lightning::routing::router::{Payee, Route, RouteParameters};
121123
use lightning::util::events::{Event, EventHandler};
122124
use lightning::util::logger::Logger;
@@ -125,21 +127,21 @@ use secp256k1::key::PublicKey;
125127

126128
use std::collections::hash_map::{self, HashMap};
127129
use std::ops::Deref;
128-
use std::sync::{Mutex, RwLock};
130+
use std::sync::Mutex;
129131
use std::time::{Duration, SystemTime};
130132

131133
/// A utility for paying [`Invoice]`s.
132134
pub struct InvoicePayer<P: Deref, R, S, L: Deref, E>
133135
where
134136
P::Target: Payer,
135137
R: Router,
136-
S: routing::Score,
138+
S: for <'a> routing::LockableScore<'a>,
137139
L::Target: Logger,
138140
E: EventHandler,
139141
{
140142
payer: P,
141143
router: R,
142-
scorer: RwLock<S>,
144+
scorer: S,
143145
logger: L,
144146
event_handler: E,
145147
payment_cache: Mutex<HashMap<PaymentHash, usize>>,
@@ -187,22 +189,11 @@ pub enum PaymentError {
187189
Sending(PaymentSendFailure),
188190
}
189191

190-
/// A read-only version of the scorer.
191-
pub struct ReadOnlyScorer<'a, S: routing::Score>(std::sync::RwLockReadGuard<'a, S>);
192-
193-
impl<'a, S: routing::Score> Deref for ReadOnlyScorer<'a, S> {
194-
type Target = S;
195-
196-
fn deref(&self) -> &Self::Target {
197-
&*self.0
198-
}
199-
}
200-
201192
impl<P: Deref, R, S, L: Deref, E> InvoicePayer<P, R, S, L, E>
202193
where
203194
P::Target: Payer,
204195
R: Router,
205-
S: routing::Score,
196+
S: for <'a> routing::LockableScore<'a>,
206197
L::Target: Logger,
207198
E: EventHandler,
208199
{
@@ -216,22 +207,14 @@ where
216207
Self {
217208
payer,
218209
router,
219-
scorer: RwLock::new(scorer),
210+
scorer,
220211
logger,
221212
event_handler,
222213
payment_cache: Mutex::new(HashMap::new()),
223214
retry_attempts,
224215
}
225216
}
226217

227-
/// Returns a read-only reference to the parameterized [`routing::Score`].
228-
///
229-
/// Useful if the scorer needs to be persisted. Be sure to drop the returned guard immediately
230-
/// after use since retrying failed payment paths require write access.
231-
pub fn scorer(&'_ self) -> ReadOnlyScorer<'_, S> {
232-
ReadOnlyScorer(self.scorer.read().unwrap())
233-
}
234-
235218
/// Pays the given [`Invoice`], caching it for later use in case a retry is needed.
236219
pub fn pay_invoice(&self, invoice: &Invoice) -> Result<PaymentId, PaymentError> {
237220
if invoice.amount_milli_satoshis().is_none() {
@@ -278,7 +261,7 @@ where
278261
&payer,
279262
&params,
280263
Some(&first_hops.iter().collect::<Vec<_>>()),
281-
&*self.scorer.read().unwrap(),
264+
&self.scorer.lock(),
282265
).map_err(|e| PaymentError::Routing(e))?;
283266

284267
let payment_hash = PaymentHash(invoice.payment_hash().clone().into_inner());
@@ -299,7 +282,7 @@ where
299282
let first_hops = self.payer.first_hops();
300283
let route = self.router.find_route(
301284
&payer, &params, Some(&first_hops.iter().collect::<Vec<_>>()),
302-
&*self.scorer.read().unwrap()
285+
&self.scorer.lock()
303286
).map_err(|e| PaymentError::Routing(e))?;
304287
self.payer.retry_payment(&route, payment_id).map_err(|e| PaymentError::Sending(e))
305288
}
@@ -326,7 +309,7 @@ impl<P: Deref, R, S, L: Deref, E> EventHandler for InvoicePayer<P, R, S, L, E>
326309
where
327310
P::Target: Payer,
328311
R: Router,
329-
S: routing::Score,
312+
S: for <'a> routing::LockableScore<'a>,
330313
L::Target: Logger,
331314
E: EventHandler,
332315
{
@@ -336,7 +319,7 @@ where
336319
payment_id, payment_hash, rejected_by_dest, path, short_channel_id, retry, ..
337320
} => {
338321
if let Some(short_channel_id) = short_channel_id {
339-
self.scorer.write().unwrap().payment_path_failed(path, *short_channel_id);
322+
self.scorer.lock().payment_path_failed(path, *short_channel_id);
340323
}
341324

342325
let mut payment_cache = self.payment_cache.lock().unwrap();
@@ -468,10 +451,10 @@ mod tests {
468451

469452
let payer = TestPayer::new();
470453
let router = TestRouter {};
471-
let scorer = TestScorer::new();
454+
let scorer = Mutex::new(TestScorer::new());
472455
let logger = TestLogger::new();
473456
let invoice_payer =
474-
InvoicePayer::new(&payer, router, scorer, &logger, event_handler, RetryAttempts(0));
457+
InvoicePayer::new(&payer, router, &scorer, &logger, event_handler, RetryAttempts(0));
475458

476459
let payment_id = Some(invoice_payer.pay_invoice(&invoice).unwrap());
477460
assert_eq!(*payer.attempts.borrow(), 1);
@@ -497,10 +480,10 @@ mod tests {
497480
.expect_value_msat(final_value_msat)
498481
.expect_value_msat(final_value_msat / 2);
499482
let router = TestRouter {};
500-
let scorer = TestScorer::new();
483+
let scorer = Mutex::new(TestScorer::new());
501484
let logger = TestLogger::new();
502485
let invoice_payer =
503-
InvoicePayer::new(&payer, router, scorer, &logger, event_handler, RetryAttempts(2));
486+
InvoicePayer::new(&payer, router, &scorer, &logger, event_handler, RetryAttempts(2));
504487

505488
let payment_id = Some(invoice_payer.pay_invoice(&invoice).unwrap());
506489
assert_eq!(*payer.attempts.borrow(), 1);
@@ -538,10 +521,10 @@ mod tests {
538521

539522
let payer = TestPayer::new();
540523
let router = TestRouter {};
541-
let scorer = TestScorer::new();
524+
let scorer = Mutex::new(TestScorer::new());
542525
let logger = TestLogger::new();
543526
let invoice_payer =
544-
InvoicePayer::new(&payer, router, scorer, &logger, event_handler, RetryAttempts(2));
527+
InvoicePayer::new(&payer, router, &scorer, &logger, event_handler, RetryAttempts(2));
545528

546529
let payment_id = Some(PaymentId([1; 32]));
547530
let event = Event::PaymentPathFailed {
@@ -583,10 +566,10 @@ mod tests {
583566
.expect_value_msat(final_value_msat / 2)
584567
.expect_value_msat(final_value_msat / 2);
585568
let router = TestRouter {};
586-
let scorer = TestScorer::new();
569+
let scorer = Mutex::new(TestScorer::new());
587570
let logger = TestLogger::new();
588571
let invoice_payer =
589-
InvoicePayer::new(&payer, router, scorer, &logger, event_handler, RetryAttempts(2));
572+
InvoicePayer::new(&payer, router, &scorer, &logger, event_handler, RetryAttempts(2));
590573

591574
let payment_id = Some(invoice_payer.pay_invoice(&invoice).unwrap());
592575
assert_eq!(*payer.attempts.borrow(), 1);
@@ -633,10 +616,10 @@ mod tests {
633616

634617
let payer = TestPayer::new();
635618
let router = TestRouter {};
636-
let scorer = TestScorer::new();
619+
let scorer = Mutex::new(TestScorer::new());
637620
let logger = TestLogger::new();
638621
let invoice_payer =
639-
InvoicePayer::new(&payer, router, scorer, &logger, event_handler, RetryAttempts(2));
622+
InvoicePayer::new(&payer, router, &scorer, &logger, event_handler, RetryAttempts(2));
640623

641624
let payment_preimage = PaymentPreimage([1; 32]);
642625
let invoice = invoice(payment_preimage);
@@ -665,10 +648,10 @@ mod tests {
665648

666649
let payer = TestPayer::new();
667650
let router = TestRouter {};
668-
let scorer = TestScorer::new();
651+
let scorer = Mutex::new(TestScorer::new());
669652
let logger = TestLogger::new();
670653
let invoice_payer =
671-
InvoicePayer::new(&payer, router, scorer, &logger, event_handler, RetryAttempts(2));
654+
InvoicePayer::new(&payer, router, &scorer, &logger, event_handler, RetryAttempts(2));
672655

673656
let payment_preimage = PaymentPreimage([1; 32]);
674657
let invoice = expired_invoice(payment_preimage);
@@ -703,10 +686,10 @@ mod tests {
703686
.fails_on_attempt(2)
704687
.expect_value_msat(final_value_msat);
705688
let router = TestRouter {};
706-
let scorer = TestScorer::new();
689+
let scorer = Mutex::new(TestScorer::new());
707690
let logger = TestLogger::new();
708691
let invoice_payer =
709-
InvoicePayer::new(&payer, router, scorer, &logger, event_handler, RetryAttempts(2));
692+
InvoicePayer::new(&payer, router, &scorer, &logger, event_handler, RetryAttempts(2));
710693

711694
let payment_id = Some(invoice_payer.pay_invoice(&invoice).unwrap());
712695
assert_eq!(*payer.attempts.borrow(), 1);
@@ -733,10 +716,10 @@ mod tests {
733716

734717
let payer = TestPayer::new();
735718
let router = TestRouter {};
736-
let scorer = TestScorer::new();
719+
let scorer = Mutex::new(TestScorer::new());
737720
let logger = TestLogger::new();
738721
let invoice_payer =
739-
InvoicePayer::new(&payer, router, scorer, &logger, event_handler, RetryAttempts(2));
722+
InvoicePayer::new(&payer, router, &scorer, &logger, event_handler, RetryAttempts(2));
740723

741724
let payment_preimage = PaymentPreimage([1; 32]);
742725
let invoice = invoice(payment_preimage);
@@ -765,10 +748,10 @@ mod tests {
765748

766749
let payer = TestPayer::new();
767750
let router = TestRouter {};
768-
let scorer = TestScorer::new();
751+
let scorer = Mutex::new(TestScorer::new());
769752
let logger = TestLogger::new();
770753
let invoice_payer =
771-
InvoicePayer::new(&payer, router, scorer, &logger, event_handler, RetryAttempts(0));
754+
InvoicePayer::new(&payer, router, &scorer, &logger, event_handler, RetryAttempts(0));
772755

773756
let payment_preimage = PaymentPreimage([1; 32]);
774757
let invoice = invoice(payment_preimage);
@@ -806,10 +789,10 @@ mod tests {
806789
fn fails_paying_invoice_with_routing_errors() {
807790
let payer = TestPayer::new();
808791
let router = FailingRouter {};
809-
let scorer = TestScorer::new();
792+
let scorer = Mutex::new(TestScorer::new());
810793
let logger = TestLogger::new();
811794
let invoice_payer =
812-
InvoicePayer::new(&payer, router, scorer, &logger, |_: &_| {}, RetryAttempts(0));
795+
InvoicePayer::new(&payer, router, &scorer, &logger, |_: &_| {}, RetryAttempts(0));
813796

814797
let payment_preimage = PaymentPreimage([1; 32]);
815798
let invoice = invoice(payment_preimage);
@@ -824,10 +807,10 @@ mod tests {
824807
fn fails_paying_invoice_with_sending_errors() {
825808
let payer = TestPayer::new().fails_on_attempt(1);
826809
let router = TestRouter {};
827-
let scorer = TestScorer::new();
810+
let scorer = Mutex::new(TestScorer::new());
828811
let logger = TestLogger::new();
829812
let invoice_payer =
830-
InvoicePayer::new(&payer, router, scorer, &logger, |_: &_| {}, RetryAttempts(0));
813+
InvoicePayer::new(&payer, router, &scorer, &logger, |_: &_| {}, RetryAttempts(0));
831814

832815
let payment_preimage = PaymentPreimage([1; 32]);
833816
let invoice = invoice(payment_preimage);
@@ -850,10 +833,10 @@ mod tests {
850833

851834
let payer = TestPayer::new().expect_value_msat(final_value_msat);
852835
let router = TestRouter {};
853-
let scorer = TestScorer::new();
836+
let scorer = Mutex::new(TestScorer::new());
854837
let logger = TestLogger::new();
855838
let invoice_payer =
856-
InvoicePayer::new(&payer, router, scorer, &logger, event_handler, RetryAttempts(0));
839+
InvoicePayer::new(&payer, router, &scorer, &logger, event_handler, RetryAttempts(0));
857840

858841
let payment_id =
859842
Some(invoice_payer.pay_zero_value_invoice(&invoice, final_value_msat).unwrap());
@@ -873,10 +856,10 @@ mod tests {
873856

874857
let payer = TestPayer::new();
875858
let router = TestRouter {};
876-
let scorer = TestScorer::new();
859+
let scorer = Mutex::new(TestScorer::new());
877860
let logger = TestLogger::new();
878861
let invoice_payer =
879-
InvoicePayer::new(&payer, router, scorer, &logger, event_handler, RetryAttempts(0));
862+
InvoicePayer::new(&payer, router, &scorer, &logger, event_handler, RetryAttempts(0));
880863

881864
let payment_preimage = PaymentPreimage([1; 32]);
882865
let invoice = invoice(payment_preimage);
@@ -903,10 +886,10 @@ mod tests {
903886

904887
let payer = TestPayer::new();
905888
let router = TestRouter {};
906-
let scorer = TestScorer::new().expect_channel_failure(short_channel_id.unwrap());
889+
let scorer = Mutex::new(TestScorer::new().expect_channel_failure(short_channel_id.unwrap()));
907890
let logger = TestLogger::new();
908891
let invoice_payer =
909-
InvoicePayer::new(&payer, router, scorer, &logger, event_handler, RetryAttempts(2));
892+
InvoicePayer::new(&payer, router, &scorer, &logger, event_handler, RetryAttempts(2));
910893

911894
let payment_id = Some(invoice_payer.pay_invoice(&invoice).unwrap());
912895
let event = Event::PaymentPathFailed {

lightning/src/routing/mod.rs

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@ use routing::network_graph::NodeId;
1717
use routing::router::RouteHop;
1818

1919
use prelude::*;
20+
use core::ops::{Deref, DerefMut};
21+
use sync::{Mutex, MutexGuard};
2022

2123
/// An interface used to score payment channels for path finding.
2224
///
@@ -29,3 +31,37 @@ pub trait Score {
2931
/// Handles updating channel penalties after failing to route through a channel.
3032
fn payment_path_failed(&mut self, path: &Vec<RouteHop>, short_channel_id: u64);
3133
}
34+
35+
/// A scorer that is accessed under a lock.
36+
///
37+
/// Needed so that calls to [`Score::channel_penalty_msat`] in [`find_route`] can be made while
38+
/// having shared ownership of a scorer but without requiring internal locking in [`Score`]
39+
/// implementations. Internal locking would be detrimental to route finding performance and could
40+
/// result in [`Score::channel_penalty_msat`] returning a different value for the same channel.
41+
///
42+
/// [`find_route`]: crate::routing::router::find_route
43+
pub trait LockableScore<'a> {
44+
/// The locked [`Score`] type.
45+
type Locked: 'a + Score;
46+
47+
/// Returns the locked scorer.
48+
fn lock(&'a self) -> Self::Locked;
49+
}
50+
51+
impl<'a, S: 'a + Score, T: Deref<Target=Mutex<S>>> LockableScore<'a> for T {
52+
type Locked = MutexGuard<'a, S>;
53+
54+
fn lock(&'a self) -> MutexGuard<'a, S> {
55+
self.deref().lock().unwrap()
56+
}
57+
}
58+
59+
impl<'a, S: Score> Score for MutexGuard<'a, S> {
60+
fn channel_penalty_msat(&self, short_channel_id: u64, source: &NodeId, target: &NodeId) -> u64 {
61+
self.deref().channel_penalty_msat(short_channel_id, source, target)
62+
}
63+
64+
fn payment_path_failed(&mut self, path: &Vec<RouteHop>, short_channel_id: u64) {
65+
self.deref_mut().payment_path_failed(path, short_channel_id)
66+
}
67+
}

0 commit comments

Comments
 (0)