diff --git a/lightning-block-sync/src/lib.rs b/lightning-block-sync/src/lib.rs index db536fe7d43..89d95a26b3c 100644 --- a/lightning-block-sync/src/lib.rs +++ b/lightning-block-sync/src/lib.rs @@ -235,7 +235,7 @@ impl<'a, P: Poll, C: Cache, L: Deref> SpvClient<'a, P, C, L> where L::Target: ch /// Returns the best polled chain tip relative to the previous best known tip and whether any /// blocks were indeed connected or disconnected. pub async fn poll_best_tip(&mut self) -> BlockSourceResult<(ChainTip, bool)> { - let chain_tip = self.chain_poller.poll_chain_tip(self.chain_tip).await?; + let chain_tip = self.chain_poller.poll_chain_tip(Some(self.chain_tip)).await?; let blocks_connected = match chain_tip { ChainTip::Common => false, ChainTip::Better(chain_tip) => { diff --git a/lightning-block-sync/src/poll.rs b/lightning-block-sync/src/poll.rs index 34be2437c8e..a6e26c57af7 100644 --- a/lightning-block-sync/src/poll.rs +++ b/lightning-block-sync/src/poll.rs @@ -15,7 +15,7 @@ use std::ops::DerefMut; /// [`ChainPoller`]: ../struct.ChainPoller.html pub trait Poll { /// Returns a chain tip in terms of its relationship to the provided chain tip. - fn poll_chain_tip<'a>(&'a mut self, best_known_chain_tip: ValidatedBlockHeader) -> + fn poll_chain_tip<'a>(&'a mut self, best_known_chain_tip: Option) -> AsyncBlockSourceResult<'a, ChainTip>; /// Returns the header that preceded the given header in the chain. @@ -174,22 +174,28 @@ impl + Sized + Sync + Send, T: BlockSource> ChainPoller + Sized + Sync + Send, T: BlockSource> Poll for ChainPoller { - fn poll_chain_tip<'a>(&'a mut self, best_known_chain_tip: ValidatedBlockHeader) -> + fn poll_chain_tip<'a>(&'a mut self, best_known_chain_tip: Option) -> AsyncBlockSourceResult<'a, ChainTip> { Box::pin(async move { let (block_hash, height) = self.block_source.get_best_block().await?; - if block_hash == best_known_chain_tip.header.block_hash() { - return Ok(ChainTip::Common); + if let Some(curr_tip) = best_known_chain_tip { + if block_hash == curr_tip.header.block_hash() { + return Ok(ChainTip::Common); + } } let chain_tip = self.block_source .get_header(&block_hash, height).await? .validate(block_hash)?; - if chain_tip.chainwork > best_known_chain_tip.chainwork { - Ok(ChainTip::Better(chain_tip)) + if let Some(curr_tip) = best_known_chain_tip { + if chain_tip.chainwork > curr_tip.chainwork { + Ok(ChainTip::Better(chain_tip)) + } else { + Ok(ChainTip::Worse(chain_tip)) + } } else { - Ok(ChainTip::Worse(chain_tip)) + Ok(ChainTip::Better(chain_tip)) } }) } @@ -234,7 +240,7 @@ mod tests { #[tokio::test] async fn poll_empty_chain() { let mut chain = Blockchain::default().with_height(0); - let best_known_chain_tip = chain.tip(); + let best_known_chain_tip = Some(chain.tip()); chain.disconnect_tip(); let mut poller = ChainPoller::new(&mut chain, Network::Bitcoin); @@ -250,7 +256,7 @@ mod tests { #[tokio::test] async fn poll_chain_without_headers() { let mut chain = Blockchain::default().with_height(1).without_headers(); - let best_known_chain_tip = chain.at_height(0); + let best_known_chain_tip = Some(chain.at_height(0)); let mut poller = ChainPoller::new(&mut chain, Network::Bitcoin); match poller.poll_chain_tip(best_known_chain_tip).await { @@ -265,7 +271,7 @@ mod tests { #[tokio::test] async fn poll_chain_with_invalid_pow() { let mut chain = Blockchain::default().with_height(1); - let best_known_chain_tip = chain.at_height(0); + let best_known_chain_tip = Some(chain.at_height(0)); // Invalidate the tip by changing its target. chain.blocks.last_mut().unwrap().header.bits = @@ -284,7 +290,7 @@ mod tests { #[tokio::test] async fn poll_chain_with_malformed_headers() { let mut chain = Blockchain::default().with_height(1).malformed_headers(); - let best_known_chain_tip = chain.at_height(0); + let best_known_chain_tip = Some(chain.at_height(0)); let mut poller = ChainPoller::new(&mut chain, Network::Bitcoin); match poller.poll_chain_tip(best_known_chain_tip).await { @@ -299,7 +305,7 @@ mod tests { #[tokio::test] async fn poll_chain_with_common_tip() { let mut chain = Blockchain::default().with_height(0); - let best_known_chain_tip = chain.tip(); + let best_known_chain_tip = Some(chain.tip()); let mut poller = ChainPoller::new(&mut chain, Network::Bitcoin); match poller.poll_chain_tip(best_known_chain_tip).await { @@ -319,7 +325,7 @@ mod tests { assert_eq!(best_known_chain_tip.chainwork, worse_chain_tip.chainwork); let mut poller = ChainPoller::new(&mut chain, Network::Bitcoin); - match poller.poll_chain_tip(best_known_chain_tip).await { + match poller.poll_chain_tip(Some(best_known_chain_tip)).await { Err(e) => panic!("Unexpected error: {:?}", e), Ok(tip) => assert_eq!(tip, ChainTip::Worse(worse_chain_tip)), } @@ -328,7 +334,7 @@ mod tests { #[tokio::test] async fn poll_chain_with_worse_tip() { let mut chain = Blockchain::default().with_height(1); - let best_known_chain_tip = chain.tip(); + let best_known_chain_tip = Some(chain.tip()); chain.disconnect_tip(); let worse_chain_tip = chain.tip(); @@ -348,9 +354,20 @@ mod tests { let better_chain_tip = chain.tip(); let mut poller = ChainPoller::new(&mut chain, Network::Bitcoin); - match poller.poll_chain_tip(best_known_chain_tip).await { + match poller.poll_chain_tip(Some(best_known_chain_tip)).await { Err(e) => panic!("Unexpected error: {:?}", e), Ok(tip) => assert_eq!(tip, ChainTip::Better(better_chain_tip)), } } + + #[tokio::test] + async fn poll_chain_no_best_tip() { + let mut chain = Blockchain::default().with_height(1); + let chain_tip = chain.tip(); + let mut poller = ChainPoller::new(&mut chain, Network::Bitcoin); + match poller.poll_chain_tip(None).await { + Err(e) => panic!("Unexpected error: {:?}", e), + Ok(tip) => assert_eq!(tip, ChainTip::Better(chain_tip)), + } + } }