diff --git a/lightning/src/onion_message/blinded_route.rs b/lightning/src/onion_message/blinded_route.rs index e47c77de354..83113fc79ba 100644 --- a/lightning/src/onion_message/blinded_route.rs +++ b/lightning/src/onion_message/blinded_route.rs @@ -13,9 +13,11 @@ use bitcoin::secp256k1::{self, PublicKey, Secp256k1, SecretKey}; use chain::keysinterface::KeysInterface; use super::utils; +use ::get_control_tlv_length; use ln::msgs::DecodeError; use util::chacha20poly1305rfc::ChaChaPolyWriteAdapter; use util::ser::{Readable, VecWriter, Writeable, Writer}; +use super::packet::{ControlTlvs, Padding}; use io; use prelude::*; @@ -54,10 +56,11 @@ impl BlindedRoute { /// will be the destination node. /// /// Errors if less than two hops are provided or if `node_pk`(s) are invalid. - // TODO: make all payloads the same size with padding + add dummy hops - pub fn new - (node_pks: &[PublicKey], keys_manager: &K, secp_ctx: &Secp256k1) -> Result - { + // TODO: Add dummy hops + pub fn new ( + node_pks: &[PublicKey], keys_manager: &K, secp_ctx: &Secp256k1, + include_next_blinding_override_padding: bool + ) -> Result { if node_pks.len() < 2 { return Err(()) } let blinding_secret_bytes = keys_manager.get_secure_random_bytes(); let blinding_secret = SecretKey::from_slice(&blinding_secret_bytes[..]).expect("RNG is busted"); @@ -66,16 +69,18 @@ impl BlindedRoute { Ok(BlindedRoute { introduction_node_id, blinding_point: PublicKey::from_secret_key(secp_ctx, &blinding_secret), - blinded_hops: blinded_hops(secp_ctx, node_pks, &blinding_secret).map_err(|_| ())?, + blinded_hops: blinded_hops(secp_ctx, node_pks, &blinding_secret, include_next_blinding_override_padding).map_err(|_| ())?, }) } } /// Construct blinded hops for the given `unblinded_path`. fn blinded_hops( - secp_ctx: &Secp256k1, unblinded_path: &[PublicKey], session_priv: &SecretKey + secp_ctx: &Secp256k1, unblinded_path: &[PublicKey], session_priv: &SecretKey, + include_next_blinding_override_padding: bool ) -> Result, secp256k1::Error> { let mut blinded_hops = Vec::with_capacity(unblinded_path.len()); + let max_length = get_control_tlv_length!(true, include_next_blinding_override_padding); let mut prev_ss_and_blinded_node_id = None; utils::construct_keys_callback(secp_ctx, unblinded_path, None, session_priv, |blinded_node_id, _, _, encrypted_payload_ss, unblinded_pk, _| { @@ -84,6 +89,7 @@ fn blinded_hops( let payload = ForwardTlvs { next_node_id: pk, next_blinding_override: None, + total_length: max_length, }; blinded_hops.push(BlindedHop { blinded_node_id: prev_blinded_node_id, @@ -95,7 +101,7 @@ fn blinded_hops( })?; if let Some((final_ss, final_blinded_node_id)) = prev_ss_and_blinded_node_id { - let final_payload = ReceiveTlvs { path_id: None }; + let final_payload = ReceiveTlvs { path_id: None, total_length: max_length, }; blinded_hops.push(BlindedHop { blinded_node_id: final_blinded_node_id, encrypted_payload: encrypt_payload(final_payload, final_ss), @@ -150,28 +156,36 @@ impl_writeable!(BlindedHop, { /// TLVs to encode in an intermediate onion message packet's hop data. When provided in a blinded /// route, they are encoded into [`BlindedHop::encrypted_payload`]. +#[derive(Clone, Copy)] pub(crate) struct ForwardTlvs { /// The node id of the next hop in the onion message's path. pub(super) next_node_id: PublicKey, /// Senders to a blinded route use this value to concatenate the route they find to the /// introduction node with the blinded route. pub(super) next_blinding_override: Option, + /// The length the tlv should have when it's serialized, with padding included if needed. + /// Used to ensure that all control tlvs in a blinded route have the same length. + pub(super) total_length: u16, } /// Similar to [`ForwardTlvs`], but these TLVs are for the final node. +#[derive(Clone, Copy)] pub(crate) struct ReceiveTlvs { /// If `path_id` is `Some`, it is used to identify the blinded route that this onion message is /// sending to. This is useful for receivers to check that said blinded route is being used in /// the right context. pub(super) path_id: Option<[u8; 32]>, + /// The length the tlv should have when it's serialized, with padding included if needed. + /// Used to ensure that all control tlvs in a blinded route have the same length. + pub(super) total_length: u16, } impl Writeable for ForwardTlvs { fn write(&self, writer: &mut W) -> Result<(), io::Error> { - // TODO: write padding encode_tlv_stream!(writer, { + (1, Padding::new_from_tlv(ControlTlvs::Forward(*self)), option), (4, self.next_node_id, required), - (8, self.next_blinding_override, option) + (8, self.next_blinding_override, option), }); Ok(()) } @@ -179,10 +193,134 @@ impl Writeable for ForwardTlvs { impl Writeable for ReceiveTlvs { fn write(&self, writer: &mut W) -> Result<(), io::Error> { - // TODO: write padding encode_tlv_stream!(writer, { + (1, Padding::new_from_tlv(ControlTlvs::Receive(*self)), option), (6, self.path_id, option), }); Ok(()) } } + +#[cfg(test)] +mod test { + use bitcoin::secp256k1::{PublicKey, SecretKey, Secp256k1}; + use ::get_control_tlv_length; + use super::{ForwardTlvs, ReceiveTlvs, blinded_hops}; + use util::ser::{VecWriter, Writeable}; + + #[test] + fn padding_is_correctly_serialized() { + let max_length = get_control_tlv_length!(true, true); + + let dummy_next_node_id = PublicKey::from_slice(&hex::decode("030101010101010101010101010101010101010101010101010101010101010101").unwrap()[..]).unwrap(); + let dummy_blinding_override = PublicKey::from_slice(&hex::decode("030202020202020202020202020202020202020202020202020202020202020202").unwrap()[..]).unwrap(); + let dummy_path_id = [1; 32]; + + let no_padding_tlv = ForwardTlvs { + next_node_id: dummy_next_node_id, + next_blinding_override: Some(dummy_blinding_override), + total_length: max_length, + }; + + let blinding_override_padding_tlv = ForwardTlvs { + next_node_id: dummy_next_node_id, + next_blinding_override: None, + total_length: max_length, + }; + + let recieve_tlv_padding_tlv = ReceiveTlvs { + path_id: Some(dummy_path_id), + total_length: max_length, + }; + + let full_padding_tlv = ReceiveTlvs { + path_id: None, + total_length: max_length, + }; + + let mut w = VecWriter(Vec::new()); + no_padding_tlv.write(&mut w).unwrap(); + let serialized_no_padding_tlv = w.0; + // As `serialized_no_padding_tlv` is the longest tlv, no padding is expected. + // Expected data tlv is: + // 1. 4 (type) for `next_node_id` + // 2. 33 (length) for the length of a point/public key + // 3. 33 bytes of the `dummy_next_node_id` + // 4. 8 (type) for `next_blinding_override` + // 5. 33 (length) for the length of a point/public key + // 6. 33 bytes of the `dummy_blinding_override` + let expected_serialized_no_padding_tlv_payload = &hex::decode("04210301010101010101010101010101010101010101010101010101010101010101010821030202020202020202020202020202020202020202020202020202020202020202").unwrap()[..]; + assert_eq!(serialized_no_padding_tlv, expected_serialized_no_padding_tlv_payload); + assert_eq!(serialized_no_padding_tlv.len(), max_length as usize); + + w = VecWriter(Vec::new()); + blinding_override_padding_tlv.write(&mut w).unwrap(); + let serialized_blinding_override_padding_tlv = w.0; + // As `serialized_blinding_override_padding_tlv` has no `next_blinding_override`, 35 bytes + // of padding is expected (the serialized length of `next_blinding_override`). + // Expected data tlv is: + // 1. 1 (type) for padding + // 2. 33 (length) given the length of a the missing `next_blinding_override` + // 3. 33 0 bytes of padding + // 4. 4 (type) for `next_node_id` + // 5. 33 (length) for the length of a point/public key + // 6. 33 bytes of the `dummy_next_node_id` + let expected_serialized_blinding_override_padding_tlv = &hex::decode("01210000000000000000000000000000000000000000000000000000000000000000000421030101010101010101010101010101010101010101010101010101010101010101").unwrap()[..]; + assert_eq!(serialized_blinding_override_padding_tlv, expected_serialized_blinding_override_padding_tlv); + assert_eq!(serialized_blinding_override_padding_tlv.len(), max_length as usize); + + w = VecWriter(Vec::new()); + recieve_tlv_padding_tlv.write(&mut w).unwrap(); + let serialized_recieve_tlv_padding_tlv = w.0; + // As `recieve_tlv_padding_tlv` is a `ReceiveTlv` and has a `path_id`, 36 bytes of padding + // is expected, ie. 70 (value of `max_length`) - 34 (the serialized length of `path_id`). + // Expected data tlv is: + // 1. 1 (type) for padding + // 2. 34 (length) given 70 - 34 + // 3. 34 0 bytes of padding + // 4. 6 (type) for `path_id` + // 5. 32 (length) for the length of a `path_id` + // 6. 32 bytes of the `path_id` + let expected_serialized_recieve_tlv_padding_tlv_payload = &hex::decode("01220000000000000000000000000000000000000000000000000000000000000000000006200101010101010101010101010101010101010101010101010101010101010101").unwrap()[..]; + assert_eq!(serialized_recieve_tlv_padding_tlv, expected_serialized_recieve_tlv_padding_tlv_payload); + assert_eq!(serialized_recieve_tlv_padding_tlv.len(), max_length as usize); + + w = VecWriter(Vec::new()); + full_padding_tlv.write(&mut w).unwrap(); + let serialized_full_padding_tlv = w.0; + // As `serialized_full_padding_tlv` is a `ReceiveTlv` with no data at alll, 70 bytes of + // padding is expected (value of `max_length`). + // Expected data tlv is: + // 1. 1 (type) for padding + // 2. 68 (length) the length of the padding minus the prefix + // 3. 68 0 bytes of padding + let expected_serialized_full_padding_tlv_payload = &hex::decode("01440000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000").unwrap()[..]; + assert_eq!(serialized_full_padding_tlv, expected_serialized_full_padding_tlv_payload); + assert_eq!(serialized_full_padding_tlv.len(), max_length as usize); + } + + #[test] + fn blinded_hops_are_same_length() { + let secp_ctx = Secp256k1::new(); + let first_node_id = PublicKey::from_secret_key(&secp_ctx, &SecretKey::from_slice(&hex::decode(format!("{:02}", 41).repeat(32)).unwrap()[..]).unwrap()); + let middle_node_id = PublicKey::from_secret_key(&secp_ctx, &SecretKey::from_slice(&hex::decode(format!("{:02}", 42).repeat(32)).unwrap()[..]).unwrap()); + let recieve_node_id = PublicKey::from_secret_key(&secp_ctx, &SecretKey::from_slice(&hex::decode(format!("{:02}", 43).repeat(32)).unwrap()[..]).unwrap()); + let session_priv = SecretKey::from_slice(&hex::decode(format!("{:02}", 3).repeat(32)).unwrap()[..]).unwrap(); + + let blinded_hops = blinded_hops(&secp_ctx, &[first_node_id, middle_node_id, recieve_node_id], &session_priv, false).unwrap(); + + // Verify that the blinded hops returned from `blinded_hops` have the same + // `encrypted_payload` length, regardless of which type of payload it is. + let mut expected_encrypted_payload_len = None; + for blinded_hop in blinded_hops { + match expected_encrypted_payload_len { + None => { + expected_encrypted_payload_len = Some(blinded_hop.encrypted_payload.len()); + }, + Some(expected_len) => { + assert_eq!(blinded_hop.encrypted_payload.len(), expected_len) + } + } + } + } +} diff --git a/lightning/src/onion_message/functional_tests.rs b/lightning/src/onion_message/functional_tests.rs index 22389bf5203..60ec06b351f 100644 --- a/lightning/src/onion_message/functional_tests.rs +++ b/lightning/src/onion_message/functional_tests.rs @@ -13,11 +13,13 @@ use chain::keysinterface::{KeysInterface, Recipient}; use ln::features::InitFeatures; use ln::msgs::{self, OnionMessageHandler}; use super::{BlindedRoute, Destination, OnionMessenger, SendError}; +use super::messenger::packet_payloads_and_keys; +use super::packet::{Payload, ForwardControlTlvs, ReceiveControlTlvs}; use util::enforcing_trait_impls::EnforcingSigner; use util::test_utils; use bitcoin::network::constants::Network; -use bitcoin::secp256k1::{PublicKey, Secp256k1}; +use bitcoin::secp256k1::{PublicKey, Secp256k1, SecretKey}; use sync::Arc; @@ -98,7 +100,7 @@ fn two_unblinded_two_blinded() { let nodes = create_nodes(5); let secp_ctx = Secp256k1::new(); - let blinded_route = BlindedRoute::new(&[nodes[3].get_node_pk(), nodes[4].get_node_pk()], &*nodes[4].keys_manager, &secp_ctx).unwrap(); + let blinded_route = BlindedRoute::new(&[nodes[3].get_node_pk(), nodes[4].get_node_pk()], &*nodes[4].keys_manager, &secp_ctx, true).unwrap(); nodes[0].messenger.send_onion_message(&[nodes[1].get_node_pk(), nodes[2].get_node_pk()], Destination::BlindedRoute(blinded_route), None).unwrap(); pass_along_path(&nodes, None); @@ -109,7 +111,7 @@ fn three_blinded_hops() { let nodes = create_nodes(4); let secp_ctx = Secp256k1::new(); - let blinded_route = BlindedRoute::new(&[nodes[1].get_node_pk(), nodes[2].get_node_pk(), nodes[3].get_node_pk()], &*nodes[3].keys_manager, &secp_ctx).unwrap(); + let blinded_route = BlindedRoute::new(&[nodes[1].get_node_pk(), nodes[2].get_node_pk(), nodes[3].get_node_pk()], &*nodes[3].keys_manager, &secp_ctx, true).unwrap(); nodes[0].messenger.send_onion_message(&[], Destination::BlindedRoute(blinded_route), None).unwrap(); pass_along_path(&nodes, None); @@ -133,13 +135,13 @@ fn invalid_blinded_route_error() { // 0 hops let secp_ctx = Secp256k1::new(); - let mut blinded_route = BlindedRoute::new(&[nodes[1].get_node_pk(), nodes[2].get_node_pk()], &*nodes[2].keys_manager, &secp_ctx).unwrap(); + let mut blinded_route = BlindedRoute::new(&[nodes[1].get_node_pk(), nodes[2].get_node_pk()], &*nodes[2].keys_manager, &secp_ctx, true).unwrap(); blinded_route.blinded_hops.clear(); let err = nodes[0].messenger.send_onion_message(&[], Destination::BlindedRoute(blinded_route), None).unwrap_err(); assert_eq!(err, SendError::TooFewBlindedHops); // 1 hop - let mut blinded_route = BlindedRoute::new(&[nodes[1].get_node_pk(), nodes[2].get_node_pk()], &*nodes[2].keys_manager, &secp_ctx).unwrap(); + let mut blinded_route = BlindedRoute::new(&[nodes[1].get_node_pk(), nodes[2].get_node_pk()], &*nodes[2].keys_manager, &secp_ctx, true).unwrap(); blinded_route.blinded_hops.remove(0); assert_eq!(blinded_route.blinded_hops.len(), 1); let err = nodes[0].messenger.send_onion_message(&[], Destination::BlindedRoute(blinded_route), None).unwrap_err(); @@ -152,7 +154,7 @@ fn reply_path() { let secp_ctx = Secp256k1::new(); // Destination::Node - let reply_path = BlindedRoute::new(&[nodes[2].get_node_pk(), nodes[1].get_node_pk(), nodes[0].get_node_pk()], &*nodes[0].keys_manager, &secp_ctx).unwrap(); + let reply_path = BlindedRoute::new(&[nodes[2].get_node_pk(), nodes[1].get_node_pk(), nodes[0].get_node_pk()], &*nodes[0].keys_manager, &secp_ctx, false).unwrap(); nodes[0].messenger.send_onion_message(&[nodes[1].get_node_pk(), nodes[2].get_node_pk()], Destination::Node(nodes[3].get_node_pk()), Some(reply_path)).unwrap(); pass_along_path(&nodes, None); // Make sure the last node successfully decoded the reply path. @@ -161,8 +163,8 @@ fn reply_path() { format!("Received an onion message with path_id: None and reply_path").to_string(), 1); // Destination::BlindedRoute - let blinded_route = BlindedRoute::new(&[nodes[1].get_node_pk(), nodes[2].get_node_pk(), nodes[3].get_node_pk()], &*nodes[3].keys_manager, &secp_ctx).unwrap(); - let reply_path = BlindedRoute::new(&[nodes[2].get_node_pk(), nodes[1].get_node_pk(), nodes[0].get_node_pk()], &*nodes[0].keys_manager, &secp_ctx).unwrap(); + let blinded_route = BlindedRoute::new(&[nodes[1].get_node_pk(), nodes[2].get_node_pk(), nodes[3].get_node_pk()], &*nodes[3].keys_manager, &secp_ctx, true).unwrap(); + let reply_path = BlindedRoute::new(&[nodes[2].get_node_pk(), nodes[1].get_node_pk(), nodes[0].get_node_pk()], &*nodes[0].keys_manager, &secp_ctx, false).unwrap(); nodes[0].messenger.send_onion_message(&[], Destination::BlindedRoute(blinded_route), Some(reply_path)).unwrap(); pass_along_path(&nodes, None); @@ -180,3 +182,75 @@ fn peer_buffer_full() { let err = nodes[0].messenger.send_onion_message(&[], Destination::Node(nodes[1].get_node_pk()), None).unwrap_err(); assert_eq!(err, SendError::BufferFull); } + +#[test] +fn onion_message_blinded_control_tlv_payloads_are_same_length() { + let nodes = create_nodes(4); + let secp_ctx = Secp256k1::new(); + let two_blinded_hops = BlindedRoute::new(&[nodes[2].get_node_pk(), nodes[3].get_node_pk()], &*nodes[3].keys_manager, &secp_ctx, true).unwrap(); + let three_blinded_hops = BlindedRoute::new(&[nodes[1].get_node_pk(), nodes[2].get_node_pk(), nodes[3].get_node_pk()], &*nodes[3].keys_manager, &secp_ctx, true).unwrap(); + let session_priv = SecretKey::from_slice(&hex::decode(format!("{:02}", 3).repeat(32)).unwrap()[..]).unwrap(); + + let only_unblinded_payloads = packet_payloads_and_keys(&secp_ctx, &[nodes[0].get_node_pk(), nodes[1].get_node_pk()], Destination::Node(nodes[2].get_node_pk()), None, &session_priv).unwrap().0; + let one_unblinded_and_three_blinded_payloads = packet_payloads_and_keys(&secp_ctx, &[nodes[1].get_node_pk()], Destination::BlindedRoute(three_blinded_hops), None, &session_priv).unwrap().0; + // When more that one unblinded payload exists, the blinded payloads should be the same length + // as the largest unblinded payload. + let multiple_unblinded_and_blinded_payloads = packet_payloads_and_keys(&secp_ctx, &[nodes[0].get_node_pk(), nodes[1].get_node_pk()], Destination::BlindedRoute(two_blinded_hops), None, &session_priv).unwrap().0; + + // Verify that the blinded contol tlv payloads returned from `packet_payloads_and_keys` have + // the same length, and that the payload for every blinded payload matches the length of the + // largest unblinded payload length. + for payloads in [only_unblinded_payloads, one_unblinded_and_three_blinded_payloads, multiple_unblinded_and_blinded_payloads].iter() { + let mut longest_tlv_length = None; + + macro_rules! assign_longest_tlv_length { + ($unblinded_tlv_length: expr) => { + if longest_tlv_length.map(|current_len| $unblinded_tlv_length > current_len).unwrap_or(true) { + longest_tlv_length = Some($unblinded_tlv_length); + } + }; + } + + macro_rules! assert_correct_tlv_length { + ($tlv_length: expr) => { + match longest_tlv_length { + None => { + longest_tlv_length = Some($tlv_length); + }, + Some(expected_len) => { + assert_eq!($tlv_length, expected_len); + } + } + }; + } + + for payload in payloads { + match &payload.0 { + Payload::Forward(control_tlvs) => { + match control_tlvs { + ForwardControlTlvs::Blinded(bytes) => { + // 16 deducted to account for the 16 byte tag of the ChaCha encryption + // in Blinded ControlTLVs + assert_correct_tlv_length!(bytes.len() as u16 - 16); + }, + ForwardControlTlvs::Unblinded(tlv) => { + assign_longest_tlv_length!(tlv.total_length); + }, + } + }, + Payload::Receive { control_tlvs, .. } => { + match control_tlvs { + ReceiveControlTlvs::Blinded(bytes) => { + // 16 deducted to account for the 16 byte tag of the ChaCha encryption + // in Blinded ControlTLVs + assert_correct_tlv_length!(bytes.len() as u16 - 16); + }, + ReceiveControlTlvs::Unblinded(tlv) => { + assign_longest_tlv_length!(tlv.total_length); + }, + } + }, + }; + } + } +} diff --git a/lightning/src/onion_message/messenger.rs b/lightning/src/onion_message/messenger.rs index e2409fc45d6..db1bb22ce4a 100644 --- a/lightning/src/onion_message/messenger.rs +++ b/lightning/src/onion_message/messenger.rs @@ -22,6 +22,7 @@ use ln::onion_utils; use super::blinded_route::{BlindedRoute, ForwardTlvs, ReceiveTlvs}; use super::packet::{BIG_PACKET_HOP_DATA_LEN, ForwardControlTlvs, Packet, Payload, ReceiveControlTlvs, SMALL_PACKET_HOP_DATA_LEN}; use super::utils; +use ::get_control_tlv_length; use util::events::OnionMessageProvider; use util::logger::Logger; use util::ser::Writeable; @@ -71,7 +72,7 @@ use prelude::*; /// // Create a blinded route to yourself, for someone to send an onion message to. /// # let your_node_id = hop_node_id1; /// let hops = [hop_node_id3, hop_node_id4, your_node_id]; -/// let blinded_route = BlindedRoute::new(&hops, &keys_manager, &secp_ctx).unwrap(); +/// let blinded_route = BlindedRoute::new(&hops, &keys_manager, &secp_ctx, true).unwrap(); /// /// // Send an empty onion message to a blinded route. /// # let intermediate_hops = [hop_node_id1, hop_node_id2]; @@ -256,14 +257,14 @@ impl OnionMessageHandler for OnionMessenger { log_info!(self.logger, "Received an onion message with path_id: {:02x?} and {}reply_path", path_id, if reply_path.is_some() { "" } else { "no " }); }, Ok((Payload::Forward(ForwardControlTlvs::Unblinded(ForwardTlvs { - next_node_id, next_blinding_override + next_node_id, next_blinding_override, .. })), Some((next_hop_hmac, new_packet_bytes)))) => { // TODO: we need to check whether `next_node_id` is our node, in which case this is a dummy // blinded hop and this onion message is destined for us. In this situation, we should keep @@ -395,7 +396,7 @@ pub type SimpleRefOnionMessenger<'a, 'b, L> = OnionMessenger( +pub(super) fn packet_payloads_and_keys( secp_ctx: &Secp256k1, unblinded_path: &[PublicKey], destination: Destination, mut reply_path: Option, session_priv: &SecretKey ) -> Result<(Vec<(Payload, [u8; 32])>, Vec), secp256k1::Error> { @@ -418,6 +419,7 @@ fn packet_payloads_and_keys( ForwardTlvs { next_node_id: unblinded_pk_opt.unwrap(), next_blinding_override: None, + total_length: get_control_tlv_length!(true, false), } )), ss)); } @@ -428,6 +430,7 @@ fn packet_payloads_and_keys( payloads.push((Payload::Forward(ForwardControlTlvs::Unblinded(ForwardTlvs { next_node_id: intro_node_id, next_blinding_override: Some(blinding_pt), + total_length: get_control_tlv_length!(true, true), })), control_tlvs_ss)); } if let Some(encrypted_payload) = enc_payload_opt { @@ -460,7 +463,7 @@ fn packet_payloads_and_keys( if let Some(control_tlvs_ss) = prev_control_tlvs_ss { payloads.push((Payload::Receive { - control_tlvs: ReceiveControlTlvs::Unblinded(ReceiveTlvs { path_id: None, }), + control_tlvs: ReceiveControlTlvs::Unblinded(ReceiveTlvs { path_id: None, total_length: get_control_tlv_length!(false), }), reply_path: reply_path.take(), }, control_tlvs_ss)); } diff --git a/lightning/src/onion_message/packet.rs b/lightning/src/onion_message/packet.rs index 1337bdb14d5..ba3b6a07fe4 100644 --- a/lightning/src/onion_message/packet.rs +++ b/lightning/src/onion_message/packet.rs @@ -14,6 +14,7 @@ use bitcoin::secp256k1::ecdh::SharedSecret; use ln::msgs::DecodeError; use ln::onion_utils; +use ::get_control_tlv_length; use super::blinded_route::{BlindedRoute, ForwardTlvs, ReceiveTlvs}; use util::chacha20poly1305rfc::{ChaChaPolyReadAdapter, ChaChaPolyWriteAdapter}; use util::ser::{BigSize, FixedLengthReader, LengthRead, LengthReadable, LengthReadableArgs, Readable, ReadableArgs, Writeable, Writer}; @@ -197,7 +198,7 @@ impl ReadableArgs for Payload { /// When reading a packet off the wire, we don't know a priori whether the packet is to be forwarded /// or received. Thus we read a ControlTlvs rather than reading a ForwardControlTlvs or /// ReceiveControlTlvs directly. -pub(super) enum ControlTlvs { +pub(crate) enum ControlTlvs { /// This onion message is intended to be forwarded. Forward(ForwardTlvs), /// This onion message is intended to be received. @@ -206,13 +207,13 @@ pub(super) enum ControlTlvs { impl Readable for ControlTlvs { fn read(mut r: &mut R) -> Result { - let mut _padding: Option = None; + let mut padding: Option = None; let mut _short_channel_id: Option = None; let mut next_node_id: Option = None; let mut path_id: Option<[u8; 32]> = None; let mut next_blinding_override: Option = None; decode_tlv_stream!(&mut r, { - (1, _padding, option), + (1, padding, option), (2, _short_channel_id, option), (4, next_node_id, option), (6, path_id, option), @@ -222,13 +223,20 @@ impl Readable for ControlTlvs { let valid_fwd_fmt = next_node_id.is_some() && path_id.is_none(); let valid_recv_fmt = next_node_id.is_none() && next_blinding_override.is_none(); + let mut total_length = get_control_tlv_length!(next_node_id.is_some(), next_blinding_override.is_some(), path_id.is_some(), 0, 0); + if padding.is_some() { + total_length += padding.unwrap().padding_length + 2 // 2 extra prefix bytes + } + let payload_fmt = if valid_fwd_fmt { ControlTlvs::Forward(ForwardTlvs { + total_length, next_node_id: next_node_id.unwrap(), next_blinding_override, }) } else if valid_recv_fmt { ControlTlvs::Receive(ReceiveTlvs { + total_length, path_id, }) } else { @@ -239,15 +247,57 @@ impl Readable for ControlTlvs { } } -/// Reads padding to the end, ignoring what's read. -pub(crate) struct Padding {} +pub(crate) struct Padding { + padding_length: u16, +} + +/// Reads padding to the end, and validates that the read bytes' content are 0s. impl Readable for Padding { #[inline] fn read(reader: &mut R) -> Result { + let mut padding_length = 0; loop { let mut buf = [0; 8192]; - if reader.read(&mut buf[..])? == 0 { break; } + let read_bytes_len = reader.read(&mut buf[..])?; + if read_bytes_len == 0 { break; } + for n in 0..read_bytes_len { + if buf[n] != 0 as u8 { return Err(DecodeError::InvalidValue); } + } + padding_length += read_bytes_len as u16; + } + Ok(Self { padding_length }) + } +} + +impl Writeable for Padding { + #[inline] + fn write(&self, writer: &mut W) -> Result<(), io::Error> { + for _ in 0..self.padding_length { + (0 as u8).write(writer)?; + } + Ok(()) + } +} + +impl Padding { + pub fn new_from_tlv(tlv: ControlTlvs) -> Option { + let (data_length, tlv_total_length) = match tlv { + ControlTlvs::Forward(forward_tlvs) => { + let data_length = get_control_tlv_length!(true, forward_tlvs.next_blinding_override.is_some()); + (data_length, forward_tlvs.total_length as u16) + }, + ControlTlvs::Receive(receive_tlvs) => { + let data_length = get_control_tlv_length!(receive_tlvs.path_id.is_some()); + (data_length, receive_tlvs.total_length as u16) + } + }; + + let extra_bytes_needed = tlv_total_length - data_length; + if extra_bytes_needed >= 2 { + let padding_length = extra_bytes_needed - 2; // 2 bytes of prefix removed + Some(Padding { padding_length }) + } else { + None } - Ok(Self {}) } } diff --git a/lightning/src/onion_message/utils.rs b/lightning/src/onion_message/utils.rs index 52cadf6c9db..ea3fd5f795d 100644 --- a/lightning/src/onion_message/utils.rs +++ b/lightning/src/onion_message/utils.rs @@ -96,3 +96,59 @@ pub(super) fn construct_keys_callback {{ // ReceiveControlTlvs + get_control_tlv_length!(false, false, $has_path_id, 0, 0) + }}; + ($has_next_node_id: expr, $has_next_blinding_override: expr) => {{ // ForwardControlTlvs + get_control_tlv_length!($has_next_node_id, $has_next_blinding_override, false, 0, 0) + }}; + ($has_next_node_id: expr, $has_next_blinding_override: expr, $has_path_id: expr, $tag_prefix_length: expr, $tag_length: expr) => {{ + // tag_prefix_length and tag_length refer to custom types in ControlTlvs, not the be + // confused with the onion message tag. + let mut res = 0; + + macro_rules! add_length { + ($should_add_len: expr, $prefix_len: expr, $content_len: expr) => { + if $should_add_len { + res += $prefix_len; + res += $content_len; + } + } + } + + add_length!($has_next_node_id, 2, 33); + add_length!($has_next_blinding_override, 2, 33); + add_length!($has_path_id, 2, 32); + add_length!($tag_length > 0, $tag_prefix_length, $tag_length); + + res + }} + + /* + TODO: + + Also add support for payment_onion ControlTlvs also consisting of: + + payment_relay: + 2 bytes prefix + 2 bytes for cltv_expiry_delta + 4 bytes for fee_proportional_millionths + 0-4 bytes for fee_base_msat (tu32) + + payment_constraints: + 2 bytes prefix + 4 bytes max_cltv_expiry + 0-8 bytes htlc_minimum_msat (tu64) + + allowed_features: + - If IS payment onion AND has NO known allowed_features: + 2 bytes prefix only + - If IS payment onion AND HAS known allowed_features: + 2 bytes prefix + X bytes of allowed_features + */ +} \ No newline at end of file