Skip to content

Commit e8854a9

Browse files
Remove unnecessary aquiring of the channel_state lock
1 parent df12df3 commit e8854a9

File tree

1 file changed

+52
-51
lines changed

1 file changed

+52
-51
lines changed

lightning/src/ln/channelmanager.rs

Lines changed: 52 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -1887,7 +1887,7 @@ impl<Signer: Sign, M: Deref, T: Deref, K: Deref, F: Deref, L: Deref> ChannelMana
18871887

18881888
for htlc_source in failed_htlcs.drain(..) {
18891889
let receiver = HTLCDestination::NextHopChannel { node_id: Some(*counterparty_node_id), channel_id: *channel_id };
1890-
self.fail_htlc_backwards_internal(self.channel_state.lock().unwrap(), htlc_source.0, &htlc_source.1, HTLCFailReason::Reason { failure_code: 0x4000 | 8, data: Vec::new() }, receiver);
1890+
self.fail_htlc_backwards_internal(htlc_source.0, &htlc_source.1, HTLCFailReason::Reason { failure_code: 0x4000 | 8, data: Vec::new() }, receiver);
18911891
}
18921892

18931893
let _ = handle_error!(self, result, *counterparty_node_id);
@@ -1945,7 +1945,7 @@ impl<Signer: Sign, M: Deref, T: Deref, K: Deref, F: Deref, L: Deref> ChannelMana
19451945
for htlc_source in failed_htlcs.drain(..) {
19461946
let (source, payment_hash, counterparty_node_id, channel_id) = htlc_source;
19471947
let receiver = HTLCDestination::NextHopChannel { node_id: Some(counterparty_node_id), channel_id: channel_id };
1948-
self.fail_htlc_backwards_internal(self.channel_state.lock().unwrap(), source, &payment_hash, HTLCFailReason::Reason { failure_code: 0x4000 | 8, data: Vec::new() }, receiver);
1948+
self.fail_htlc_backwards_internal(source, &payment_hash, HTLCFailReason::Reason { failure_code: 0x4000 | 8, data: Vec::new() }, receiver);
19491949
}
19501950
if let Some((funding_txo, monitor_update)) = monitor_update_option {
19511951
// There isn't anything we can do if we get an update failure - we're already
@@ -3005,13 +3005,12 @@ impl<Signer: Sign, M: Deref, T: Deref, K: Deref, F: Deref, L: Deref> ChannelMana
30053005
let mut phantom_receives: Vec<(u64, OutPoint, Vec<(PendingHTLCInfo, u64)>)> = Vec::new();
30063006
let mut handle_errors = Vec::new();
30073007
{
3008-
let mut channel_state_lock = self.channel_state.lock().unwrap();
3009-
let channel_state = &mut *channel_state_lock;
3010-
30113008
let mut forward_htlcs = HashMap::new();
30123009
mem::swap(&mut forward_htlcs, &mut self.forward_htlcs.lock().unwrap());
30133010

30143011
for (short_chan_id, mut pending_forwards) in forward_htlcs {
3012+
let mut channel_state_lock = self.channel_state.lock().unwrap();
3013+
let channel_state = &mut *channel_state_lock;
30153014
if short_chan_id != 0 {
30163015
let forward_chan_id = match channel_state.short_to_chan_info.get(&short_chan_id) {
30173016
Some((_cp_id, chan_id)) => chan_id.clone(),
@@ -3409,7 +3408,7 @@ impl<Signer: Sign, M: Deref, T: Deref, K: Deref, F: Deref, L: Deref> ChannelMana
34093408
}
34103409

34113410
for (htlc_source, payment_hash, failure_reason, destination) in failed_forwards.drain(..) {
3412-
self.fail_htlc_backwards_internal(self.channel_state.lock().unwrap(), htlc_source, &payment_hash, failure_reason, destination);
3411+
self.fail_htlc_backwards_internal(htlc_source, &payment_hash, failure_reason, destination);
34133412
}
34143413
self.forward_htlcs(&mut phantom_receives);
34153414

@@ -3633,7 +3632,7 @@ impl<Signer: Sign, M: Deref, T: Deref, K: Deref, F: Deref, L: Deref> ChannelMana
36333632

36343633
for htlc_source in timed_out_mpp_htlcs.drain(..) {
36353634
let receiver = HTLCDestination::FailedPayment { payment_hash: htlc_source.1 };
3636-
self.fail_htlc_backwards_internal(self.channel_state.lock().unwrap(), HTLCSource::PreviousHopData(htlc_source.0.clone()), &htlc_source.1, HTLCFailReason::Reason { failure_code: 23, data: Vec::new() }, receiver );
3635+
self.fail_htlc_backwards_internal(HTLCSource::PreviousHopData(htlc_source.0.clone()), &htlc_source.1, HTLCFailReason::Reason { failure_code: 23, data: Vec::new() }, receiver );
36373636
}
36383637

36393638
for (err, counterparty_node_id) in handle_errors.drain(..) {
@@ -3659,15 +3658,16 @@ impl<Signer: Sign, M: Deref, T: Deref, K: Deref, F: Deref, L: Deref> ChannelMana
36593658
pub fn fail_htlc_backwards(&self, payment_hash: &PaymentHash) {
36603659
let _persistence_guard = PersistenceNotifierGuard::notify_on_drop(&self.total_consistency_lock, &self.persistence_notifier);
36613660

3662-
let mut channel_state = Some(self.channel_state.lock().unwrap());
3663-
let removed_source = channel_state.as_mut().unwrap().claimable_htlcs.remove(payment_hash);
3661+
let removed_source = {
3662+
let mut channel_state = self.channel_state.lock().unwrap();
3663+
channel_state.claimable_htlcs.remove(payment_hash)
3664+
};
36643665
if let Some((_, mut sources)) = removed_source {
36653666
for htlc in sources.drain(..) {
3666-
if channel_state.is_none() { channel_state = Some(self.channel_state.lock().unwrap()); }
36673667
let mut htlc_msat_height_data = byte_utils::be64_to_array(htlc.value).to_vec();
36683668
htlc_msat_height_data.extend_from_slice(&byte_utils::be32_to_array(
36693669
self.best_block.read().unwrap().height()));
3670-
self.fail_htlc_backwards_internal(channel_state.take().unwrap(),
3670+
self.fail_htlc_backwards_internal(
36713671
HTLCSource::PreviousHopData(htlc.prev_hop), payment_hash,
36723672
HTLCFailReason::Reason { failure_code: 0x4000 | 15, data: htlc_msat_height_data },
36733673
HTLCDestination::FailedPayment { payment_hash: *payment_hash });
@@ -3730,27 +3730,31 @@ impl<Signer: Sign, M: Deref, T: Deref, K: Deref, F: Deref, L: Deref> ChannelMana
37303730
counterparty_node_id: &PublicKey
37313731
) {
37323732
for (htlc_src, payment_hash) in htlcs_to_fail.drain(..) {
3733-
let mut channel_state = self.channel_state.lock().unwrap();
37343733
let (failure_code, onion_failure_data) =
3735-
match channel_state.by_id.entry(channel_id) {
3734+
match self.channel_state.lock().unwrap().by_id.entry(channel_id) {
37363735
hash_map::Entry::Occupied(chan_entry) => {
37373736
self.get_htlc_inbound_temp_fail_err_and_data(0x1000|7, &chan_entry.get())
37383737
},
37393738
hash_map::Entry::Vacant(_) => (0x4000|10, Vec::new())
37403739
};
37413740

37423741
let receiver = HTLCDestination::NextHopChannel { node_id: Some(counterparty_node_id.clone()), channel_id };
3743-
self.fail_htlc_backwards_internal(channel_state, htlc_src, &payment_hash, HTLCFailReason::Reason { failure_code, data: onion_failure_data }, receiver);
3742+
self.fail_htlc_backwards_internal(htlc_src, &payment_hash, HTLCFailReason::Reason { failure_code, data: onion_failure_data }, receiver);
37443743
}
37453744
}
37463745

37473746
/// Fails an HTLC backwards to the sender of it to us.
3748-
/// Note that while we take a channel_state lock as input, we do *not* assume consistency here.
3749-
/// There are several callsites that do stupid things like loop over a list of payment_hashes
3750-
/// to fail and take the channel_state lock for each iteration (as we take ownership and may
3751-
/// drop it). In other words, no assumptions are made that entries in claimable_htlcs point to
3752-
/// still-available channels.
3753-
fn fail_htlc_backwards_internal(&self, mut channel_state_lock: MutexGuard<ChannelHolder<Signer>>, source: HTLCSource, payment_hash: &PaymentHash, onion_error: HTLCFailReason, destination: HTLCDestination) {
3747+
/// Note that we do not assume that channels corresponding to failed HTLCs are still available.
3748+
fn fail_htlc_backwards_internal(&self, source: HTLCSource, payment_hash: &PaymentHash, onion_error: HTLCFailReason,destination: HTLCDestination) {
3749+
#[cfg(debug_assertions)]
3750+
{
3751+
// Ensure that the `channel_state` lock is not held when calling this function.
3752+
// This ensures that future code doesn't introduce a lock_order requirement for
3753+
// `forward_htlcs` to be locked after the `channel_state` lock, which calling this
3754+
// function with the `channel_state` locked would.
3755+
assert!(self.channel_state.try_lock().is_ok());
3756+
}
3757+
37543758
//TODO: There is a timing attack here where if a node fails an HTLC back to us they can
37553759
//identify whether we sent it or not based on the (I presume) very different runtime
37563760
//between the branches here. We should make this async and move it into the forward HTLCs
@@ -3789,7 +3793,6 @@ impl<Signer: Sign, M: Deref, T: Deref, K: Deref, F: Deref, L: Deref> ChannelMana
37893793
log_trace!(self.logger, "Received duplicative fail for HTLC with payment_hash {}", log_bytes!(payment_hash.0));
37903794
return;
37913795
}
3792-
mem::drop(channel_state_lock);
37933796
let mut retry = if let Some(payment_params_data) = payment_params {
37943797
let path_last_hop = path.last().expect("Outbound payments must have had a valid path");
37953798
Some(RouteParameters {
@@ -3923,7 +3926,6 @@ impl<Signer: Sign, M: Deref, T: Deref, K: Deref, F: Deref, L: Deref> ChannelMana
39233926
}
39243927
}
39253928
mem::drop(forward_htlcs);
3926-
mem::drop(channel_state_lock);
39273929
let mut pending_events = self.pending_events.lock().unwrap();
39283930
if let Some(time) = forward_event {
39293931
pending_events.push(events::Event::PendingHTLCsForwardable {
@@ -3961,8 +3963,7 @@ impl<Signer: Sign, M: Deref, T: Deref, K: Deref, F: Deref, L: Deref> ChannelMana
39613963

39623964
let _persistence_guard = PersistenceNotifierGuard::notify_on_drop(&self.total_consistency_lock, &self.persistence_notifier);
39633965

3964-
let mut channel_state = Some(self.channel_state.lock().unwrap());
3965-
let removed_source = channel_state.as_mut().unwrap().claimable_htlcs.remove(&payment_hash);
3966+
let removed_source = self.channel_state.lock().unwrap().claimable_htlcs.remove(&payment_hash);
39663967
if let Some((payment_purpose, mut sources)) = removed_source {
39673968
assert!(!sources.is_empty());
39683969

@@ -3980,8 +3981,12 @@ impl<Signer: Sign, M: Deref, T: Deref, K: Deref, F: Deref, L: Deref> ChannelMana
39803981
let mut claimable_amt_msat = 0;
39813982
let mut expected_amt_msat = None;
39823983
let mut valid_mpp = true;
3984+
let mut errs = Vec::new();
3985+
let mut claimed_any_htlcs = false;
3986+
let mut channel_state_lock = self.channel_state.lock().unwrap();
3987+
let channel_state = &mut *channel_state_lock;
39833988
for htlc in sources.iter() {
3984-
if let None = channel_state.as_ref().unwrap().short_to_chan_info.get(&htlc.prev_hop.short_channel_id) {
3989+
if let None = channel_state.short_to_chan_info.get(&htlc.prev_hop.short_channel_id) {
39853990
valid_mpp = false;
39863991
break;
39873992
}
@@ -4014,21 +4019,9 @@ impl<Signer: Sign, M: Deref, T: Deref, K: Deref, F: Deref, L: Deref> ChannelMana
40144019
expected_amt_msat.unwrap(), claimable_amt_msat);
40154020
return;
40164021
}
4017-
4018-
let mut errs = Vec::new();
4019-
let mut claimed_any_htlcs = false;
4020-
for htlc in sources.drain(..) {
4021-
if !valid_mpp {
4022-
if channel_state.is_none() { channel_state = Some(self.channel_state.lock().unwrap()); }
4023-
let mut htlc_msat_height_data = byte_utils::be64_to_array(htlc.value).to_vec();
4024-
htlc_msat_height_data.extend_from_slice(&byte_utils::be32_to_array(
4025-
self.best_block.read().unwrap().height()));
4026-
self.fail_htlc_backwards_internal(channel_state.take().unwrap(),
4027-
HTLCSource::PreviousHopData(htlc.prev_hop), &payment_hash,
4028-
HTLCFailReason::Reason { failure_code: 0x4000|15, data: htlc_msat_height_data },
4029-
HTLCDestination::FailedPayment { payment_hash } );
4030-
} else {
4031-
match self.claim_funds_from_hop(channel_state.as_mut().unwrap(), htlc.prev_hop, payment_preimage) {
4022+
if valid_mpp {
4023+
for htlc in sources.drain(..) {
4024+
match self.claim_funds_from_hop(&mut channel_state_lock, htlc.prev_hop, payment_preimage) {
40324025
ClaimFundsFromHop::MonitorUpdateFail(pk, err, _) => {
40334026
if let msgs::ErrorAction::IgnoreError = err.err.action {
40344027
// We got a temporary failure updating monitor, but will claim the
@@ -4048,6 +4041,18 @@ impl<Signer: Sign, M: Deref, T: Deref, K: Deref, F: Deref, L: Deref> ChannelMana
40484041
}
40494042
}
40504043
}
4044+
mem::drop(channel_state_lock);
4045+
if !valid_mpp {
4046+
for htlc in sources.drain(..) {
4047+
let mut htlc_msat_height_data = byte_utils::be64_to_array(htlc.value).to_vec();
4048+
htlc_msat_height_data.extend_from_slice(&byte_utils::be32_to_array(
4049+
self.best_block.read().unwrap().height()));
4050+
self.fail_htlc_backwards_internal(
4051+
HTLCSource::PreviousHopData(htlc.prev_hop), &payment_hash,
4052+
HTLCFailReason::Reason { failure_code: 0x4000|15, data: htlc_msat_height_data },
4053+
HTLCDestination::FailedPayment { payment_hash } );
4054+
}
4055+
}
40514056

40524057
if claimed_any_htlcs {
40534058
self.pending_events.lock().unwrap().push(events::Event::PaymentClaimed {
@@ -4057,10 +4062,7 @@ impl<Signer: Sign, M: Deref, T: Deref, K: Deref, F: Deref, L: Deref> ChannelMana
40574062
});
40584063
}
40594064

4060-
// Now that we've done the entire above loop in one lock, we can handle any errors
4061-
// which were generated.
4062-
channel_state.take();
4063-
4065+
// Now we can handle any errors which were generated.
40644066
for (counterparty_node_id, err) in errs.drain(..) {
40654067
let res: Result<(), _> = Err(err);
40664068
let _ = handle_error!(self, res, counterparty_node_id);
@@ -4307,7 +4309,7 @@ impl<Signer: Sign, M: Deref, T: Deref, K: Deref, F: Deref, L: Deref> ChannelMana
43074309
self.finalize_claims(finalized_claims);
43084310
for failure in pending_failures.drain(..) {
43094311
let receiver = HTLCDestination::NextHopChannel { node_id: Some(counterparty_node_id), channel_id: funding_txo.to_channel_id() };
4310-
self.fail_htlc_backwards_internal(self.channel_state.lock().unwrap(), failure.0, &failure.1, failure.2, receiver);
4312+
self.fail_htlc_backwards_internal(failure.0, &failure.1, failure.2, receiver);
43114313
}
43124314
}
43134315

@@ -4667,7 +4669,7 @@ impl<Signer: Sign, M: Deref, T: Deref, K: Deref, F: Deref, L: Deref> ChannelMana
46674669
};
46684670
for htlc_source in dropped_htlcs.drain(..) {
46694671
let receiver = HTLCDestination::NextHopChannel { node_id: Some(counterparty_node_id.clone()), channel_id: msg.channel_id };
4670-
self.fail_htlc_backwards_internal(self.channel_state.lock().unwrap(), htlc_source.0, &htlc_source.1, HTLCFailReason::Reason { failure_code: 0x4000 | 8, data: Vec::new() }, receiver);
4672+
self.fail_htlc_backwards_internal(htlc_source.0, &htlc_source.1, HTLCFailReason::Reason { failure_code: 0x4000 | 8, data: Vec::new() }, receiver);
46714673
}
46724674

46734675
let _ = handle_error!(self, result, *counterparty_node_id);
@@ -4869,7 +4871,6 @@ impl<Signer: Sign, M: Deref, T: Deref, K: Deref, F: Deref, L: Deref> ChannelMana
48694871
for &mut (prev_short_channel_id, prev_funding_outpoint, ref mut pending_forwards) in per_source_pending_forwards {
48704872
let mut forward_event = None;
48714873
if !pending_forwards.is_empty() {
4872-
let mut channel_state = self.channel_state.lock().unwrap();
48734874
let mut forward_htlcs = self.forward_htlcs.lock().unwrap();
48744875
if forward_htlcs.is_empty() {
48754876
forward_event = Some(Duration::from_millis(MIN_HTLC_RELAY_HOLDING_CELL_MILLIS))
@@ -4956,7 +4957,7 @@ impl<Signer: Sign, M: Deref, T: Deref, K: Deref, F: Deref, L: Deref> ChannelMana
49564957
{
49574958
for failure in pending_failures.drain(..) {
49584959
let receiver = HTLCDestination::NextHopChannel { node_id: Some(*counterparty_node_id), channel_id: channel_outpoint.to_channel_id() };
4959-
self.fail_htlc_backwards_internal(self.channel_state.lock().unwrap(), failure.0, &failure.1, failure.2, receiver);
4960+
self.fail_htlc_backwards_internal(failure.0, &failure.1, failure.2, receiver);
49604961
}
49614962
self.forward_htlcs(&mut [(short_channel_id, channel_outpoint, pending_forwards)]);
49624963
self.finalize_claims(finalized_claim_htlcs);
@@ -5114,7 +5115,7 @@ impl<Signer: Sign, M: Deref, T: Deref, K: Deref, F: Deref, L: Deref> ChannelMana
51145115
} else {
51155116
log_trace!(self.logger, "Failing HTLC with hash {} from our monitor", log_bytes!(htlc_update.payment_hash.0));
51165117
let receiver = HTLCDestination::NextHopChannel { node_id: counterparty_node_id, channel_id: funding_outpoint.to_channel_id() };
5117-
self.fail_htlc_backwards_internal(self.channel_state.lock().unwrap(), htlc_update.source, &htlc_update.payment_hash, HTLCFailReason::Reason { failure_code: 0x4000 | 8, data: Vec::new() }, receiver);
5118+
self.fail_htlc_backwards_internal(htlc_update.source, &htlc_update.payment_hash, HTLCFailReason::Reason { failure_code: 0x4000 | 8, data: Vec::new() }, receiver);
51185119
}
51195120
},
51205121
MonitorEvent::CommitmentTxConfirmed(funding_outpoint) |
@@ -5848,7 +5849,7 @@ where
58485849
self.handle_init_event_channel_failures(failed_channels);
58495850

58505851
for (source, payment_hash, reason, destination) in timed_out_htlcs.drain(..) {
5851-
self.fail_htlc_backwards_internal(self.channel_state.lock().unwrap(), source, &payment_hash, reason, destination);
5852+
self.fail_htlc_backwards_internal(source, &payment_hash, reason, destination);
58525853
}
58535854
}
58545855

@@ -7210,7 +7211,7 @@ impl<'a, Signer: Sign, M: Deref, T: Deref, K: Deref, F: Deref, L: Deref>
72107211
for htlc_source in failed_htlcs.drain(..) {
72117212
let (source, payment_hash, counterparty_node_id, channel_id) = htlc_source;
72127213
let receiver = HTLCDestination::NextHopChannel { node_id: Some(counterparty_node_id), channel_id };
7213-
channel_manager.fail_htlc_backwards_internal(channel_manager.channel_state.lock().unwrap(), source, &payment_hash, HTLCFailReason::Reason { failure_code: 0x4000 | 8, data: Vec::new() }, receiver);
7214+
channel_manager.fail_htlc_backwards_internal(source, &payment_hash, HTLCFailReason::Reason { failure_code: 0x4000 | 8, data: Vec::new() }, receiver);
72147215
}
72157216

72167217
//TODO: Broadcast channel update for closed channels, but only after we've made a

0 commit comments

Comments
 (0)