diff --git a/lightning/src/ln/msgs.rs b/lightning/src/ln/msgs.rs index 0e5b2e07e7a..d40de9f705f 100644 --- a/lightning/src/ln/msgs.rs +++ b/lightning/src/ln/msgs.rs @@ -790,7 +790,7 @@ pub struct CommitmentUpdate { /// Messages could have optional fields to use with extended features /// As we wish to serialize these differently from Options (Options get a tag byte, but -/// OptionalFeild simply gets Present if there are enough bytes to read into it), we have a +/// OptionalField simply gets Present if there are enough bytes to read into it), we have a /// separate enum type for them. /// (C-not exported) due to a free generic in T #[derive(Clone, Debug, PartialEq)] diff --git a/lightning/src/util/mod.rs b/lightning/src/util/mod.rs index c6181ab269a..301ba6afe72 100644 --- a/lightning/src/util/mod.rs +++ b/lightning/src/util/mod.rs @@ -13,7 +13,7 @@ pub(crate) mod fuzz_wrappers; #[macro_use] -pub(crate) mod ser_macros; +pub mod ser_macros; pub mod events; pub mod errors; diff --git a/lightning/src/util/ser.rs b/lightning/src/util/ser.rs index 428adbc5e66..6b519b9eb1f 100644 --- a/lightning/src/util/ser.rs +++ b/lightning/src/util/ser.rs @@ -91,21 +91,24 @@ impl Writer for LengthCalculatingWriter { /// Essentially std::io::Take but a bit simpler and with a method to walk the underlying stream /// forward to ensure we always consume exactly the fixed length specified. -pub(crate) struct FixedLengthReader { +pub struct FixedLengthReader { read: R, bytes_read: u64, total_bytes: u64, } impl FixedLengthReader { + /// Returns a new FixedLengthReader. pub fn new(read: R, total_bytes: u64) -> Self { Self { read, bytes_read: 0, total_bytes } } + /// Returns whether there are remaining bytes or not. #[inline] pub fn bytes_remain(&mut self) -> bool { self.bytes_read != self.total_bytes } + /// Consume the remaning bytes. #[inline] pub fn eat_remaining(&mut self) -> Result<(), DecodeError> { copy(self, &mut sink()).unwrap(); @@ -136,11 +139,13 @@ impl Read for FixedLengthReader { /// A Read which tracks whether any bytes have been read at all. This allows us to distinguish /// between "EOF reached before we started" and "EOF reached mid-read". -pub(crate) struct ReadTrackingReader { +pub struct ReadTrackingReader { read: R, + /// Tells whether we have read from this reader or not yet. pub have_read: bool, } impl ReadTrackingReader { + /// Returns a new ReadTrackingReader. pub fn new(read: R) -> Self { Self { read, have_read: false } } @@ -237,7 +242,8 @@ impl MaybeReadable for T { } } -pub(crate) struct OptionDeserWrapper(pub Option); +/// Wrapper to read a required (non-optional) TLV record. +pub struct OptionDeserWrapper(pub Option); impl Readable for OptionDeserWrapper { #[inline] fn read(reader: &mut R) -> Result { @@ -246,7 +252,7 @@ impl Readable for OptionDeserWrapper { } /// Wrapper to write each element of a Vec with no length prefix -pub(crate) struct VecWriteWrapper<'a, T: Writeable>(pub &'a Vec); +pub struct VecWriteWrapper<'a, T: Writeable>(pub &'a Vec); impl<'a, T: Writeable> Writeable for VecWriteWrapper<'a, T> { #[inline] fn write(&self, writer: &mut W) -> Result<(), io::Error> { @@ -258,7 +264,7 @@ impl<'a, T: Writeable> Writeable for VecWriteWrapper<'a, T> { } /// Wrapper to read elements from a given stream until it reaches the end of the stream. -pub(crate) struct VecReadWrapper(pub Vec); +pub struct VecReadWrapper(pub Vec); impl Readable for VecReadWrapper { #[inline] fn read(mut reader: &mut R) -> Result { diff --git a/lightning/src/util/ser_macros.rs b/lightning/src/util/ser_macros.rs index 165d1f1edba..7fda9d20ed3 100644 --- a/lightning/src/util/ser_macros.rs +++ b/lightning/src/util/ser_macros.rs @@ -7,9 +7,14 @@ // You may not use this file except in accordance with one or both of these // licenses. +//! Some macros that implement Readable/Writeable traits for lightning messages. +//! They also handle serialization and deserialization of TLVs. + +/// Implements serialization for a single TLV record. +#[macro_export] macro_rules! encode_tlv { ($stream: expr, $type: expr, $field: expr, (default_value, $default: expr)) => { - encode_tlv!($stream, $type, $field, required) + $crate::encode_tlv!($stream, $type, $field, required) }; ($stream: expr, $type: expr, $field: expr, required) => { BigSize($type).write($stream)?; @@ -17,7 +22,7 @@ macro_rules! encode_tlv { $field.write($stream)?; }; ($stream: expr, $type: expr, $field: expr, vec_type) => { - encode_tlv!($stream, $type, ::util::ser::VecWriteWrapper(&$field), required); + $crate::encode_tlv!($stream, $type, ser::VecWriteWrapper(&$field), required); }; ($stream: expr, $optional_type: expr, $optional_field: expr, option) => { if let Some(ref field) = $optional_field { @@ -28,17 +33,19 @@ macro_rules! encode_tlv { }; } +/// Implements the TLVs serialization part in a Writeable implementation of a struct. +#[macro_export] macro_rules! encode_tlv_stream { ($stream: expr, {$(($type: expr, $field: expr, $fieldty: tt)),* $(,)*}) => { { #[allow(unused_imports)] use { - ln::msgs::DecodeError, - util::ser, - util::ser::BigSize, + $crate::ln::msgs::DecodeError, + $crate::util::ser, + $crate::util::ser::BigSize, }; $( - encode_tlv!($stream, $type, $field, $fieldty); + $crate::encode_tlv!($stream, $type, $field, $fieldty); )* #[allow(unused_mut, unused_variables, unused_assignments)] @@ -47,7 +54,8 @@ macro_rules! encode_tlv_stream { let mut last_seen: Option = None; $( if let Some(t) = last_seen { - debug_assert!(t <= $type); + #[allow(unused_comparisons)] // Note that $type may be 0 making the following comparison always false + (debug_assert!($type > t)) } last_seen = Some($type); )* @@ -66,7 +74,7 @@ macro_rules! get_varint_length_prefixed_tlv_length { $len.0 += field_len; }; ($len: expr, $type: expr, $field: expr, vec_type) => { - get_varint_length_prefixed_tlv_length!($len, $type, ::util::ser::VecWriteWrapper(&$field), required); + get_varint_length_prefixed_tlv_length!($len, $type, $crate::util::ser::VecWriteWrapper(&$field), required); }; ($len: expr, $optional_type: expr, $optional_field: expr, option) => { if let Some(ref field) = $optional_field { @@ -80,10 +88,10 @@ macro_rules! get_varint_length_prefixed_tlv_length { macro_rules! encode_varint_length_prefixed_tlv { ($stream: expr, {$(($type: expr, $field: expr, $fieldty: tt)),*}) => { { - use util::ser::BigSize; + use $crate::util::ser::BigSize; let len = { #[allow(unused_mut)] - let mut len = ::util::ser::LengthCalculatingWriter(0); + let mut len = $crate::util::ser::LengthCalculatingWriter(0); $( get_varint_length_prefixed_tlv_length!(len, $type, $field, $fieldty); )* @@ -94,16 +102,18 @@ macro_rules! encode_varint_length_prefixed_tlv { } } } +/// Errors if there are missing required TLV types between the last seen type and the type currently being processed. +#[macro_export] macro_rules! check_tlv_order { ($last_seen_type: expr, $typ: expr, $type: expr, $field: ident, (default_value, $default: expr)) => {{ - #[allow(unused_comparisons)] // Note that $type may be 0 making the second comparison always true + #[allow(unused_comparisons)] // Note that $type may be 0 making the second comparison always false let invalid_order = ($last_seen_type.is_none() || $last_seen_type.unwrap() < $type) && $typ.0 > $type; if invalid_order { $field = $default; } }}; ($last_seen_type: expr, $typ: expr, $type: expr, $field: ident, required) => {{ - #[allow(unused_comparisons)] // Note that $type may be 0 making the second comparison always true + #[allow(unused_comparisons)] // Note that $type may be 0 making the second comparison always false let invalid_order = ($last_seen_type.is_none() || $last_seen_type.unwrap() < $type) && $typ.0 > $type; if invalid_order { return Err(DecodeError::InvalidValue); @@ -120,16 +130,18 @@ macro_rules! check_tlv_order { }}; } +/// Errors if there are missing required TLV types after the last seen type. +#[macro_export] macro_rules! check_missing_tlv { ($last_seen_type: expr, $type: expr, $field: ident, (default_value, $default: expr)) => {{ - #[allow(unused_comparisons)] // Note that $type may be 0 making the second comparison always true + #[allow(unused_comparisons)] // Note that $type may be 0 making the second comparison always false let missing_req_type = $last_seen_type.is_none() || $last_seen_type.unwrap() < $type; if missing_req_type { $field = $default; } }}; ($last_seen_type: expr, $type: expr, $field: ident, required) => {{ - #[allow(unused_comparisons)] // Note that $type may be 0 making the second comparison always true + #[allow(unused_comparisons)] // Note that $type may be 0 making the second comparison always false let missing_req_type = $last_seen_type.is_none() || $last_seen_type.unwrap() < $type; if missing_req_type { return Err(DecodeError::InvalidValue); @@ -146,15 +158,17 @@ macro_rules! check_missing_tlv { }}; } +/// Implements deserialization for a single TLV record. +#[macro_export] macro_rules! decode_tlv { ($reader: expr, $field: ident, (default_value, $default: expr)) => {{ - decode_tlv!($reader, $field, required) + $crate::decode_tlv!($reader, $field, required) }}; ($reader: expr, $field: ident, required) => {{ $field = ser::Readable::read(&mut $reader)?; }}; ($reader: expr, $field: ident, vec_type) => {{ - let f: ::util::ser::VecReadWrapper<_> = ser::Readable::read(&mut $reader)?; + let f: $crate::util::ser::VecReadWrapper<_> = ser::Readable::read(&mut $reader)?; $field = Some(f.0); }}; ($reader: expr, $field: ident, option) => {{ @@ -165,20 +179,22 @@ macro_rules! decode_tlv { }}; } +/// Implements the TLVs deserialization part in a Readable implementation of a struct. +#[macro_export] macro_rules! decode_tlv_stream { ($stream: expr, {$(($type: expr, $field: ident, $fieldty: tt)),* $(,)*}) => { { - use ln::msgs::DecodeError; + use $crate::ln::msgs::DecodeError; let mut last_seen_type: Option = None; let mut stream_ref = $stream; 'tlv_read: loop { - use util::ser; + use $crate::util::ser; // First decode the type of this TLV: let typ: ser::BigSize = { // We track whether any bytes were read during the consensus_decode call to // determine whether we should break or return ShortRead if we get an // UnexpectedEof. This should in every case be largely cosmetic, but its nice to - // pass the TLV test vectors exactly, which requre this distinction. + // pass the TLV test vectors exactly, which require this distinction. let mut tracking_reader = ser::ReadTrackingReader::new(&mut stream_ref); match ser::Readable::read(&mut tracking_reader) { Err(DecodeError::ShortRead) => { @@ -200,9 +216,9 @@ macro_rules! decode_tlv_stream { }, _ => {}, } - // As we read types, make sure we hit every required type: + // As we read types, make sure we hit every required type between last_seen_type and typ: $({ - check_tlv_order!(last_seen_type, typ, $type, $field, $fieldty); + $crate::check_tlv_order!(last_seen_type, typ, $type, $field, $fieldty); })* last_seen_type = Some(typ.0); @@ -211,7 +227,7 @@ macro_rules! decode_tlv_stream { let mut s = ser::FixedLengthReader::new(&mut stream_ref, length.0); match typ.0 { $($type => { - decode_tlv!(s, $field, $fieldty); + $crate::decode_tlv!(s, $field, $fieldty); if s.bytes_remain() { s.eat_remaining()?; // Return ShortRead if there's actually not enough bytes return Err(DecodeError::InvalidValue); @@ -226,25 +242,48 @@ macro_rules! decode_tlv_stream { } // Make sure we got to each required type after we've read every TLV: $({ - check_missing_tlv!(last_seen_type, $type, $field, $fieldty); + $crate::check_missing_tlv!(last_seen_type, $type, $field, $fieldty); })* } } } +/// Implements Readable/Writeable for a struct. This macro also handles (de)serialization of TLV records. +/// # Example +/// ``` +/// #[derive(Debug)] +/// pub struct LightningMessage { +/// pub to: String, +/// pub note: String, +/// pub secret_number: u64, +/// // TLV records +/// pub nick_name: Option, +/// pub street_number: Option, +/// } +/// +/// lightning::impl_writeable_msg!(LightningMessage, { +/// to, +/// note, +/// secret_number, +/// }, { +/// (1, nick_name, option), +/// (3, street_number, option), +/// }); +/// ``` +#[macro_export] macro_rules! impl_writeable_msg { ($st:ident, {$($field:ident),* $(,)*}, {$(($type: expr, $tlvfield: ident, $fieldty: tt)),* $(,)*}) => { - impl ::util::ser::Writeable for $st { - fn write(&self, w: &mut W) -> Result<(), $crate::io::Error> { + impl $crate::util::ser::Writeable for $st { + fn write(&self, w: &mut W) -> Result<(), $crate::io::Error> { $( self.$field.write(w)?; )* - encode_tlv_stream!(w, {$(($type, self.$tlvfield, $fieldty)),*}); + $crate::encode_tlv_stream!(w, {$(($type, self.$tlvfield, $fieldty)),*}); Ok(()) } } - impl ::util::ser::Readable for $st { - fn read(r: &mut R) -> Result { - $(let $field = ::util::ser::Readable::read(r)?;)* - $(init_tlv_field_var!($tlvfield, $fieldty);)* - decode_tlv_stream!(r, {$(($type, $tlvfield, $fieldty)),*}); + impl $crate::util::ser::Readable for $st { + fn read(r: &mut R) -> Result { + $(let $field = $crate::util::ser::Readable::read(r)?;)* + $($crate::init_tlv_field_var!($tlvfield, $fieldty);)* + $crate::decode_tlv_stream!(r, {$(($type, $tlvfield, $fieldty)),*}); Ok(Self { $($field),*, $($tlvfield),* @@ -254,10 +293,13 @@ macro_rules! impl_writeable_msg { } } +/// Implements Readable/Writeable for a struct. Note that this macro doesn't handle messages +/// containing TLV records. If your message contains TLVs, use `impl_writeable_msg!` instead. +#[macro_export] macro_rules! impl_writeable { ($st:ident, {$($field:ident),*}) => { - impl ::util::ser::Writeable for $st { - fn write(&self, w: &mut W) -> Result<(), $crate::io::Error> { + impl $crate::util::ser::Writeable for $st { + fn write(&self, w: &mut W) -> Result<(), $crate::io::Error> { $( self.$field.write(w)?; )* Ok(()) } @@ -270,10 +312,10 @@ macro_rules! impl_writeable { } } - impl ::util::ser::Readable for $st { - fn read(r: &mut R) -> Result { + impl $crate::util::ser::Readable for $st { + fn read(r: &mut R) -> Result { Ok(Self { - $($field: ::util::ser::Readable::read(r)?),* + $($field: $crate::util::ser::Readable::read(r)?),* }) } } @@ -331,10 +373,10 @@ macro_rules! read_ver_prefix { /// Reads a suffix added by write_tlv_fields. macro_rules! read_tlv_fields { ($stream: expr, {$(($type: expr, $field: ident, $fieldty: tt)),* $(,)*}) => { { - let tlv_len: ::util::ser::BigSize = ::util::ser::Readable::read($stream)?; - let mut rd = ::util::ser::FixedLengthReader::new($stream, tlv_len.0); + let tlv_len: $crate::util::ser::BigSize = $crate::util::ser::Readable::read($stream)?; + let mut rd = $crate::util::ser::FixedLengthReader::new($stream, tlv_len.0); decode_tlv_stream!(&mut rd, {$(($type, $field, $fieldty)),*}); - rd.eat_remaining().map_err(|_| ::ln::msgs::DecodeError::ShortRead)?; + rd.eat_remaining().map_err(|_| $crate::ln::msgs::DecodeError::ShortRead)?; } } } @@ -353,12 +395,14 @@ macro_rules! init_tlv_based_struct_field { }; } +/// Initializes the variables we are going to read the TLVs into. +#[macro_export] macro_rules! init_tlv_field_var { ($field: ident, (default_value, $default: expr)) => { let mut $field = $default; }; ($field: ident, required) => { - let mut $field = ::util::ser::OptionDeserWrapper(None); + let mut $field = $crate::util::ser::OptionDeserWrapper(None); }; ($field: ident, vec_type) => { let mut $field = Some(Vec::new()); @@ -375,8 +419,8 @@ macro_rules! init_tlv_field_var { /// serialized. macro_rules! impl_writeable_tlv_based { ($st: ident, {$(($type: expr, $field: ident, $fieldty: tt)),* $(,)*}) => { - impl ::util::ser::Writeable for $st { - fn write(&self, writer: &mut W) -> Result<(), $crate::io::Error> { + impl $crate::util::ser::Writeable for $st { + fn write(&self, writer: &mut W) -> Result<(), $crate::io::Error> { write_tlv_fields!(writer, { $(($type, self.$field, $fieldty)),* }); @@ -385,23 +429,23 @@ macro_rules! impl_writeable_tlv_based { #[inline] fn serialized_length(&self) -> usize { - use util::ser::BigSize; + use $crate::util::ser::BigSize; let len = { #[allow(unused_mut)] - let mut len = ::util::ser::LengthCalculatingWriter(0); + let mut len = $crate::util::ser::LengthCalculatingWriter(0); $( get_varint_length_prefixed_tlv_length!(len, $type, self.$field, $fieldty); )* len.0 }; - let mut len_calc = ::util::ser::LengthCalculatingWriter(0); + let mut len_calc = $crate::util::ser::LengthCalculatingWriter(0); BigSize(len as u64).write(&mut len_calc).expect("No in-memory data may fail to serialize"); len + len_calc.0 } } - impl ::util::ser::Readable for $st { - fn read(reader: &mut R) -> Result { + impl $crate::util::ser::Readable for $st { + fn read(reader: &mut R) -> Result { $( init_tlv_field_var!($field, $fieldty); )* @@ -423,8 +467,8 @@ macro_rules! _impl_writeable_tlv_based_enum_common { {$(($type: expr, $field: ident, $fieldty: tt)),* $(,)*} ),* $(,)*; $(($tuple_variant_id: expr, $tuple_variant_name: ident)),* $(,)*) => { - impl ::util::ser::Writeable for $st { - fn write(&self, writer: &mut W) -> Result<(), $crate::io::Error> { + impl $crate::util::ser::Writeable for $st { + fn write(&self, writer: &mut W) -> Result<(), $crate::io::Error> { match self { $($st::$variant_name { $(ref $field),* } => { let id: u8 = $variant_id; @@ -462,9 +506,9 @@ macro_rules! impl_writeable_tlv_based_enum_upgradable { $(($variant_id, $variant_name) => {$(($type, $field, $fieldty)),*}),*; $($(($tuple_variant_id, $tuple_variant_name)),*)*); - impl ::util::ser::MaybeReadable for $st { - fn read(reader: &mut R) -> Result, ::ln::msgs::DecodeError> { - let id: u8 = ::util::ser::Readable::read(reader)?; + impl $crate::util::ser::MaybeReadable for $st { + fn read(reader: &mut R) -> Result, $crate::ln::msgs::DecodeError> { + let id: u8 = $crate::util::ser::Readable::read(reader)?; match id { $($variant_id => { // Because read_tlv_fields creates a labeled loop, we cannot call it twice @@ -515,9 +559,9 @@ macro_rules! impl_writeable_tlv_based_enum { $(($variant_id, $variant_name) => {$(($type, $field, $fieldty)),*}),*; $(($tuple_variant_id, $tuple_variant_name)),*); - impl ::util::ser::Readable for $st { - fn read(reader: &mut R) -> Result { - let id: u8 = ::util::ser::Readable::read(reader)?; + impl $crate::util::ser::Readable for $st { + fn read(reader: &mut R) -> Result { + let id: u8 = $crate::util::ser::Readable::read(reader)?; match id { $($variant_id => { // Because read_tlv_fields creates a labeled loop, we cannot call it twice