diff --git a/crates/chain-state/src/in_memory.rs b/crates/chain-state/src/in_memory.rs index 0971a939dc04..140d78d75bbf 100644 --- a/crates/chain-state/src/in_memory.rs +++ b/crates/chain-state/src/in_memory.rs @@ -9,7 +9,7 @@ use reth_chainspec::ChainInfo; use reth_execution_types::{Chain, ExecutionOutcome}; use reth_primitives::{ Address, BlockNumHash, Header, Receipt, Receipts, SealedBlock, SealedBlockWithSenders, - SealedHeader, B256, + SealedHeader, TransactionMeta, TransactionSigned, TxHash, B256, }; use reth_storage_api::StateProviderBox; use reth_trie::{updates::TrieUpdates, HashedPostState}; @@ -417,6 +417,64 @@ impl CanonicalInMemoryState { MemoryOverlayStateProvider::new(in_memory, historical) } + + /// Returns an iterator over all canonical blocks in the in-memory state, from newest to oldest. + pub fn canonical_chain(&self) -> impl Iterator> { + let pending = self.inner.in_memory_state.pending.read().clone(); + let head = self.inner.in_memory_state.head_state(); + + // this clone is cheap because we only expect to keep in memory a few + // blocks and all of them are Arcs. + let blocks = self.inner.in_memory_state.blocks.read().clone(); + + std::iter::once(pending).filter_map(|p| p.map(Arc::new)).chain(std::iter::successors( + head, + move |state| { + let parent_hash = state.block().block().parent_hash; + blocks.get(&parent_hash).cloned() + }, + )) + } + + /// Returns a `TransactionSigned` for the given `TxHash` if found. + pub fn transaction_by_hash(&self, hash: TxHash) -> Option { + for block_state in self.canonical_chain() { + if let Some(tx) = block_state.block().block().body.iter().find(|tx| tx.hash() == hash) { + return Some(tx.clone()) + } + } + None + } + + /// Returns a tuple with `TransactionSigned` and `TransactionMeta` for the + /// given `TxHash` if found. + pub fn transaction_by_hash_with_meta( + &self, + tx_hash: TxHash, + ) -> Option<(TransactionSigned, TransactionMeta)> { + for (block_number, block_state) in self.canonical_chain().enumerate() { + if let Some((index, tx)) = block_state + .block() + .block() + .body + .iter() + .enumerate() + .find(|(_, tx)| tx.hash() == tx_hash) + { + let meta = TransactionMeta { + tx_hash, + index: index as u64, + block_hash: block_state.hash(), + block_number: block_number as u64, + base_fee: block_state.block().block().header.base_fee_per_gas, + timestamp: block_state.block().block.timestamp, + excess_blob_gas: block_state.block().block.excess_blob_gas, + }; + return Some((tx.clone(), meta)) + } + } + None + } } /// State after applying the given block, this block is part of the canonical chain that partially @@ -986,6 +1044,81 @@ mod tests { assert_eq!(empty_overlay_provider.in_memory.len(), 0); } + #[test] + fn test_canonical_in_memory_state_canonical_chain_empty() { + let state = CanonicalInMemoryState::new(HashMap::new(), HashMap::new(), None, None); + let chain: Vec<_> = state.canonical_chain().collect(); + assert!(chain.is_empty()); + } + + #[test] + fn test_canonical_in_memory_state_canonical_chain_single_block() { + let block = TestBlockBuilder::default().get_executed_block_with_number(1, B256::random()); + let hash = block.block().hash(); + let mut blocks = HashMap::new(); + blocks.insert(hash, Arc::new(BlockState::new(block))); + let mut numbers = HashMap::new(); + numbers.insert(1, hash); + + let state = CanonicalInMemoryState::new(blocks, numbers, None, None); + let chain: Vec<_> = state.canonical_chain().collect(); + + assert_eq!(chain.len(), 1); + assert_eq!(chain[0].number(), 1); + assert_eq!(chain[0].hash(), hash); + } + + #[test] + fn test_canonical_in_memory_state_canonical_chain_multiple_blocks() { + let mut blocks = HashMap::new(); + let mut numbers = HashMap::new(); + let mut parent_hash = B256::random(); + let mut block_builder = TestBlockBuilder::default(); + + for i in 1..=3 { + let block = block_builder.get_executed_block_with_number(i, parent_hash); + let hash = block.block().hash(); + blocks.insert(hash, Arc::new(BlockState::new(block.clone()))); + numbers.insert(i, hash); + parent_hash = hash; + } + + let state = CanonicalInMemoryState::new(blocks, numbers, None, None); + let chain: Vec<_> = state.canonical_chain().collect(); + + assert_eq!(chain.len(), 3); + assert_eq!(chain[0].number(), 3); + assert_eq!(chain[1].number(), 2); + assert_eq!(chain[2].number(), 1); + } + + #[test] + fn test_canonical_in_memory_state_canonical_chain_with_pending_block() { + let mut blocks = HashMap::new(); + let mut numbers = HashMap::new(); + let mut parent_hash = B256::random(); + let mut block_builder = TestBlockBuilder::default(); + + for i in 1..=2 { + let block = block_builder.get_executed_block_with_number(i, parent_hash); + let hash = block.block().hash(); + blocks.insert(hash, Arc::new(BlockState::new(block.clone()))); + numbers.insert(i, hash); + parent_hash = hash; + } + + let pending_block = block_builder.get_executed_block_with_number(3, parent_hash); + let pending_state = BlockState::new(pending_block); + + let state = CanonicalInMemoryState::new(blocks, numbers, Some(pending_state), None); + let chain: Vec<_> = state.canonical_chain().collect(); + + assert_eq!(chain.len(), 3); + assert_eq!(chain[0].number(), 3); + assert_eq!(chain[1].number(), 2); + assert_eq!(chain[2].number(), 1); + } + #[test] fn test_block_state_parent_blocks() { let mut test_block_builder = TestBlockBuilder::default(); diff --git a/crates/storage/provider/src/providers/blockchain_provider.rs b/crates/storage/provider/src/providers/blockchain_provider.rs index a0aded2cee59..c35543b61650 100644 --- a/crates/storage/provider/src/providers/blockchain_provider.rs +++ b/crates/storage/provider/src/providers/blockchain_provider.rs @@ -26,7 +26,7 @@ use reth_stages_types::{StageCheckpoint, StageId}; use reth_storage_errors::provider::ProviderResult; use revm::primitives::{BlockEnv, CfgEnvWithHandlerCfg}; use std::{ - ops::{RangeBounds, RangeInclusive}, + ops::{Add, Bound, RangeBounds, RangeInclusive, Sub}, sync::Arc, time::Instant, }; @@ -95,6 +95,30 @@ where pub fn canonical_in_memory_state(&self) -> CanonicalInMemoryState { self.canonical_in_memory_state.clone() } + + // Helper function to convert range bounds + fn convert_range_bounds( + &self, + range: impl RangeBounds, + end_unbounded: impl FnOnce() -> T, + ) -> (T, T) + where + T: Copy + Add + Sub + From, + { + let start = match range.start_bound() { + Bound::Included(&n) => n, + Bound::Excluded(&n) => n + T::from(1u8), + Bound::Unbounded => T::from(0u8), + }; + + let end = match range.end_bound() { + Bound::Included(&n) => n, + Bound::Excluded(&n) => n - T::from(1u8), + Bound::Unbounded => end_unbounded(), + }; + + (start, end) + } } impl BlockchainProvider2 @@ -140,10 +164,18 @@ where DB: Database, { fn header(&self, block_hash: &BlockHash) -> ProviderResult> { + if let Some(block_state) = self.canonical_in_memory_state.state_by_hash(*block_hash) { + return Ok(Some(block_state.block().block().header.header().clone())); + } + self.database.header(block_hash) } fn header_by_number(&self, num: BlockNumber) -> ProviderResult> { + if let Some(block_state) = self.canonical_in_memory_state.state_by_number(num) { + return Ok(Some(block_state.block().block().header.header().clone())); + } + self.database.header_by_number(num) } @@ -156,10 +188,31 @@ where } fn headers_range(&self, range: impl RangeBounds) -> ProviderResult> { - self.database.headers_range(range) + let mut headers = Vec::new(); + let (start, end) = self.convert_range_bounds(range, || { + self.canonical_in_memory_state.get_canonical_block_number() + }); + + for num in start..=end { + if let Some(block_state) = self.canonical_in_memory_state.state_by_number(num) { + // TODO: there might be an update between loop iterations, we + // need to handle that situation. + headers.push(block_state.block().block().header.header().clone()); + } else { + let mut db_headers = self.database.headers_range(num..=end)?; + headers.append(&mut db_headers); + break; + } + } + + Ok(headers) } fn sealed_header(&self, number: BlockNumber) -> ProviderResult> { + if let Some(block_state) = self.canonical_in_memory_state.state_by_number(number) { + return Ok(Some(block_state.block().block().header.clone())); + } + self.database.sealed_header(number) } @@ -167,15 +220,53 @@ where &self, range: impl RangeBounds, ) -> ProviderResult> { - self.database.sealed_headers_range(range) + let mut sealed_headers = Vec::new(); + let (start, end) = self.convert_range_bounds(range, || { + self.canonical_in_memory_state.get_canonical_block_number() + }); + + for num in start..=end { + if let Some(block_state) = self.canonical_in_memory_state.state_by_number(num) { + // TODO: there might be an update between loop iterations, we + // need to handle that situation. + sealed_headers.push(block_state.block().block().header.clone()); + } else { + let mut db_headers = self.database.sealed_headers_range(num..=end)?; + sealed_headers.append(&mut db_headers); + break; + } + } + + Ok(sealed_headers) } fn sealed_headers_while( &self, range: impl RangeBounds, - predicate: impl FnMut(&SealedHeader) -> bool, + mut predicate: impl FnMut(&SealedHeader) -> bool, ) -> ProviderResult> { - self.database.sealed_headers_while(range, predicate) + let mut headers = Vec::new(); + let (start, end) = self.convert_range_bounds(range, || { + self.canonical_in_memory_state.get_canonical_block_number() + }); + + for num in start..=end { + if let Some(block_state) = self.canonical_in_memory_state.state_by_number(num) { + let header = block_state.block().block().header.clone(); + if !predicate(&header) { + break; + } + headers.push(header); + } else { + let mut db_headers = self.database.sealed_headers_while(num..=end, predicate)?; + // TODO: there might be an update between loop iterations, we + // need to handle that situation. + headers.append(&mut db_headers); + break; + } + } + + Ok(headers) } } @@ -184,6 +275,10 @@ where DB: Database, { fn block_hash(&self, number: u64) -> ProviderResult> { + if let Some(block_state) = self.canonical_in_memory_state.state_by_number(number) { + return Ok(Some(block_state.hash())); + } + self.database.block_hash(number) } @@ -192,7 +287,19 @@ where start: BlockNumber, end: BlockNumber, ) -> ProviderResult> { - self.database.canonical_hashes_range(start, end) + let mut hashes = Vec::with_capacity((end - start + 1) as usize); + for number in start..=end { + if let Some(block_state) = self.canonical_in_memory_state.state_by_number(number) { + hashes.push(block_state.hash()); + } else { + let mut db_hashes = self.database.canonical_hashes_range(number, end)?; + // TODO: there might be an update between loop iterations, we + // need to handle that situation. + hashes.append(&mut db_hashes); + break; + } + } + Ok(hashes) } } @@ -213,6 +320,10 @@ where } fn block_number(&self, hash: B256) -> ProviderResult> { + if let Some(block_state) = self.canonical_in_memory_state.state_by_hash(hash) { + return Ok(Some(block_state.number())); + } + self.database.block_number(hash) } } @@ -239,33 +350,32 @@ where DB: Database, { fn find_block_by_hash(&self, hash: B256, source: BlockSource) -> ProviderResult> { - let block = match source { + match source { BlockSource::Any | BlockSource::Canonical => { // check in memory first // Note: it's fine to return the unsealed block because the caller already has // the hash - let mut block = self - .canonical_in_memory_state - .state_by_hash(hash) - .map(|block_state| block_state.block().block().clone().unseal()); - - if block.is_none() { - block = self.database.block_by_hash(hash)?; + if let Some(block_state) = self.canonical_in_memory_state.state_by_hash(hash) { + return Ok(Some(block_state.block().block().clone().unseal())); } - block + self.database.find_block_by_hash(hash, source) } BlockSource::Pending => { - self.canonical_in_memory_state.pending_block().map(|block| block.unseal()) + Ok(self.canonical_in_memory_state.pending_block().map(|block| block.unseal())) } - }; - - Ok(block) + } } fn block(&self, id: BlockHashOrNumber) -> ProviderResult> { match id { BlockHashOrNumber::Hash(hash) => self.find_block_by_hash(hash, BlockSource::Any), - BlockHashOrNumber::Number(num) => self.database.block_by_number(num), + BlockHashOrNumber::Number(num) => { + if let Some(block_state) = self.canonical_in_memory_state.state_by_number(num) { + return Ok(Some(block_state.block().block().clone().unseal())); + } + + self.database.block_by_number(num) + } } } @@ -303,6 +413,22 @@ where id: BlockHashOrNumber, transaction_kind: TransactionVariant, ) -> ProviderResult> { + match id { + BlockHashOrNumber::Hash(hash) => { + if let Some(block_state) = self.canonical_in_memory_state.state_by_hash(hash) { + let block = block_state.block().block().clone(); + let senders = block_state.block().senders().clone(); + return Ok(Some(BlockWithSenders { block: block.unseal(), senders })); + } + } + BlockHashOrNumber::Number(num) => { + if let Some(block_state) = self.canonical_in_memory_state.state_by_number(num) { + let block = block_state.block().block().clone(); + let senders = block_state.block().senders().clone(); + return Ok(Some(BlockWithSenders { block: block.unseal(), senders })); + } + } + } self.database.block_with_senders(id, transaction_kind) } @@ -311,25 +437,88 @@ where id: BlockHashOrNumber, transaction_kind: TransactionVariant, ) -> ProviderResult> { + match id { + BlockHashOrNumber::Hash(hash) => { + if let Some(block_state) = self.canonical_in_memory_state.state_by_hash(hash) { + let block = block_state.block().block().clone(); + let senders = block_state.block().senders().clone(); + return Ok(Some(SealedBlockWithSenders { block, senders })); + } + } + BlockHashOrNumber::Number(num) => { + if let Some(block_state) = self.canonical_in_memory_state.state_by_number(num) { + let block = block_state.block().block().clone(); + let senders = block_state.block().senders().clone(); + return Ok(Some(SealedBlockWithSenders { block, senders })); + } + } + } self.database.sealed_block_with_senders(id, transaction_kind) } fn block_range(&self, range: RangeInclusive) -> ProviderResult> { - self.database.block_range(range) + let capacity = (range.end() - range.start() + 1) as usize; + let mut blocks = Vec::with_capacity(capacity); + + for num in range.clone() { + if let Some(block_state) = self.canonical_in_memory_state.state_by_number(num) { + // TODO: there might be an update between loop iterations, we + // need to handle that situation. + blocks.push(block_state.block().block().clone().unseal()); + } else { + let mut db_blocks = self.database.block_range(num..=*range.end())?; + blocks.append(&mut db_blocks); + break; + } + } + Ok(blocks) } fn block_with_senders_range( &self, range: RangeInclusive, ) -> ProviderResult> { - self.database.block_with_senders_range(range) + let capacity = (range.end() - range.start() + 1) as usize; + let mut blocks = Vec::with_capacity(capacity); + + for num in range.clone() { + if let Some(block_state) = self.canonical_in_memory_state.state_by_number(num) { + let block = block_state.block().block().clone(); + let senders = block_state.block().senders().clone(); + // TODO: there might be an update between loop iterations, we + // need to handle that situation. + blocks.push(BlockWithSenders { block: block.unseal(), senders }); + } else { + let mut db_blocks = self.database.block_with_senders_range(num..=*range.end())?; + blocks.append(&mut db_blocks); + break; + } + } + Ok(blocks) } fn sealed_block_with_senders_range( &self, range: RangeInclusive, ) -> ProviderResult> { - self.database.sealed_block_with_senders_range(range) + let capacity = (range.end() - range.start() + 1) as usize; + let mut blocks = Vec::with_capacity(capacity); + + for num in range.clone() { + if let Some(block_state) = self.canonical_in_memory_state.state_by_number(num) { + let block = block_state.block().block().clone(); + let senders = block_state.block().senders().clone(); + // TODO: there might be an update between loop iterations, we + // need to handle that situation. + blocks.push(SealedBlockWithSenders { block, senders }); + } else { + let mut db_blocks = + self.database.sealed_block_with_senders_range(num..=*range.end())?; + blocks.append(&mut db_blocks); + break; + } + } + Ok(blocks) } } @@ -353,6 +542,10 @@ where } fn transaction_by_hash(&self, hash: TxHash) -> ProviderResult> { + if let Some(tx) = self.canonical_in_memory_state.transaction_by_hash(hash) { + return Ok(Some(tx)) + } + self.database.transaction_by_hash(hash) } @@ -360,6 +553,12 @@ where &self, tx_hash: TxHash, ) -> ProviderResult> { + if let Some((tx, meta)) = + self.canonical_in_memory_state.transaction_by_hash_with_meta(tx_hash) + { + return Ok(Some((tx, meta))) + } + self.database.transaction_by_hash_with_meta(tx_hash) } @@ -371,6 +570,18 @@ where &self, id: BlockHashOrNumber, ) -> ProviderResult>> { + match id { + BlockHashOrNumber::Hash(hash) => { + if let Some(block_state) = self.canonical_in_memory_state.state_by_hash(hash) { + return Ok(Some(block_state.block().block().body.clone())); + } + } + BlockHashOrNumber::Number(number) => { + if let Some(block_state) = self.canonical_in_memory_state.state_by_number(number) { + return Ok(Some(block_state.block().block().body.clone())); + } + } + } self.database.transactions_by_block(id) } @@ -378,7 +589,35 @@ where &self, range: impl RangeBounds, ) -> ProviderResult>> { - self.database.transactions_by_block_range(range) + let (start, end) = self.convert_range_bounds(range, || { + self.canonical_in_memory_state.get_canonical_block_number() + }); + + let mut transactions = Vec::new(); + let mut last_in_memory_block = None; + + for number in start..=end { + if let Some(block_state) = self.canonical_in_memory_state.state_by_number(number) { + // TODO: there might be an update between loop iterations, we + // need to handle that situation. + transactions.push(block_state.block().block().body.clone()); + last_in_memory_block = Some(number); + } else { + break; + } + } + + if let Some(last_block) = last_in_memory_block { + if last_block < end { + let mut db_transactions = + self.database.transactions_by_block_range((last_block + 1)..=end)?; + transactions.append(&mut db_transactions); + } + } else { + transactions = self.database.transactions_by_block_range(start..=end)?; + } + + Ok(transactions) } fn transactions_by_tx_range( @@ -409,10 +648,41 @@ where } fn receipt_by_hash(&self, hash: TxHash) -> ProviderResult> { + for block_state in self.canonical_in_memory_state.canonical_chain() { + let executed_block = block_state.block(); + let block = executed_block.block(); + let receipts = block_state.executed_block_receipts(); + + // assuming 1:1 correspondence between transactions and receipts + debug_assert_eq!( + block.body.len(), + receipts.len(), + "Mismatch between transaction and receipt count" + ); + + if let Some(tx_index) = block.body.iter().position(|tx| tx.hash() == hash) { + // safe to use tx_index for receipts due to 1:1 correspondence + return Ok(receipts.get(tx_index).cloned()); + } + } + self.database.receipt_by_hash(hash) } fn receipts_by_block(&self, block: BlockHashOrNumber) -> ProviderResult>> { + match block { + BlockHashOrNumber::Hash(hash) => { + if let Some(block_state) = self.canonical_in_memory_state.state_by_hash(hash) { + return Ok(Some(block_state.executed_block_receipts())); + } + } + BlockHashOrNumber::Number(number) => { + if let Some(block_state) = self.canonical_in_memory_state.state_by_number(number) { + return Ok(Some(block_state.executed_block_receipts())); + } + } + } + self.database.receipts_by_block(block) }