Skip to content

Commit

Permalink
test(engine): enable state root task in engine unit tests
Browse files Browse the repository at this point in the history
  • Loading branch information
fgimenez committed Feb 20, 2025
1 parent 0b708de commit b6a3f1c
Show file tree
Hide file tree
Showing 5 changed files with 234 additions and 54 deletions.
7 changes: 5 additions & 2 deletions crates/engine/local/src/service.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ use reth_engine_tree::{
RequestHandlerEvent,
},
persistence::PersistenceHandle,
tree::{EngineApiTreeHandler, InvalidBlockHook, TreeConfig},
tree::{root::BasicStateRootTaskFactory, EngineApiTreeHandler, InvalidBlockHook, TreeConfig},
};
use reth_evm::{execute::BlockExecutorProvider, ConfigureEvm};
use reth_node_types::{BlockTy, HeaderTy, TxTy};
Expand Down Expand Up @@ -95,8 +95,10 @@ where
PersistenceHandle::<N::Primitives>::spawn_service(provider, pruner, sync_metrics_tx);
let canonical_in_memory_state = blockchain_db.canonical_in_memory_state();

let state_root_task_factory = BasicStateRootTaskFactory::new();

let (to_tree_tx, from_tree) =
EngineApiTreeHandler::<N::Primitives, _, _, _, _, _>::spawn_new(
EngineApiTreeHandler::<N::Primitives, _, _, _, _, _, _>::spawn_new(
blockchain_db.clone(),
executor_factory,
consensus,
Expand All @@ -108,6 +110,7 @@ where
invalid_block_hook,
engine_kind,
evm_config,
state_root_task_factory,
);

let handler = EngineApiRequestHandler::new(to_tree_tx, from_tree);
Expand Down
7 changes: 5 additions & 2 deletions crates/engine/service/src/service.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ use reth_engine_tree::{
download::BasicBlockDownloader,
engine::{EngineApiKind, EngineApiRequest, EngineApiRequestHandler, EngineHandler},
persistence::PersistenceHandle,
tree::{EngineApiTreeHandler, InvalidBlockHook, TreeConfig},
tree::{root::BasicStateRootTaskFactory, EngineApiTreeHandler, InvalidBlockHook, TreeConfig},
};
pub use reth_engine_tree::{
chain::{ChainEvent, ChainOrchestrator},
Expand Down Expand Up @@ -105,8 +105,10 @@ where

let canonical_in_memory_state = blockchain_db.canonical_in_memory_state();

let state_root_task_factory = BasicStateRootTaskFactory::new();

let (to_tree_tx, from_tree) =
EngineApiTreeHandler::<N::Primitives, _, _, _, _, _>::spawn_new(
EngineApiTreeHandler::<N::Primitives, _, _, _, _, _, _>::spawn_new(
blockchain_db,
executor_factory,
consensus,
Expand All @@ -118,6 +120,7 @@ where
invalid_block_hook,
engine_kind,
evm_config,
state_root_task_factory,
);

let engine_handler = EngineApiRequestHandler::new(to_tree_tx, from_tree);
Expand Down
154 changes: 115 additions & 39 deletions crates/engine/tree/src/tree/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,6 @@ use crate::{
chain::FromOrchestrator,
engine::{DownloadRequest, EngineApiEvent, EngineApiKind, EngineApiRequest, FromEngine},
persistence::PersistenceHandle,
tree::{
cached_state::{CachedStateMetrics, CachedStateProvider, ProviderCacheBuilder},
metrics::EngineApiMetrics,
},
};
use alloy_consensus::{transaction::Recovered, BlockHeader};
use alloy_eips::BlockNumHash;
Expand All @@ -18,9 +14,11 @@ use alloy_primitives::{
use alloy_rpc_types_engine::{
ForkchoiceState, PayloadStatus, PayloadStatusEnum, PayloadValidationError,
};
use cached_state::{ProviderCaches, SavedCache};
use cached_state::{
CachedStateMetrics, CachedStateProvider, ProviderCacheBuilder, ProviderCaches, SavedCache,
};
use error::{InsertBlockError, InsertBlockErrorKind, InsertBlockFatalError};
use metrics::PrewarmThreadMetrics;
use metrics::{EngineApiMetrics, PrewarmThreadMetrics};
use persistence_state::CurrentPersistenceAction;
use reth_chain_state::{
CanonicalInMemoryState, ExecutedBlock, ExecutedBlockWithTrieUpdates,
Expand Down Expand Up @@ -59,7 +57,8 @@ use reth_trie::{
use reth_trie_db::DatabaseTrieCursorFactory;
use reth_trie_parallel::root::{ParallelStateRoot, ParallelStateRootError};
use root::{
StateRootComputeOutcome, StateRootConfig, StateRootHandle, StateRootMessage, StateRootTask,
StateRootComputeHandle, StateRootComputeOutcome, StateRootConfig, StateRootMessage,
StateRootTaskFactory, StateRootTaskRunner,
};
use std::{
cmp::Ordering,
Expand Down Expand Up @@ -556,10 +555,11 @@ pub enum TreeAction {
///
/// This type is responsible for processing engine API requests, maintaining the canonical state and
/// emitting events.
pub struct EngineApiTreeHandler<N, P, E, T, V, C>
pub struct EngineApiTreeHandler<N, P, E, T, V, C, F>
where
N: NodePrimitives,
T: EngineTypes,
F: StateRootTaskFactory<P>,
{
provider: P,
executor_provider: E,
Expand Down Expand Up @@ -603,14 +603,15 @@ where
engine_kind: EngineApiKind,
/// The most recent cache used for execution.
most_recent_cache: Option<SavedCache>,
/// Thread pool used for the state root task and prewarming
thread_pool: Arc<rayon::ThreadPool>,
/// Factory for state root tasks.
state_root_task_factory: F,
}

impl<N, P: Debug, E: Debug, T: EngineTypes + Debug, V: Debug, C: Debug> std::fmt::Debug
for EngineApiTreeHandler<N, P, E, T, V, C>
impl<N, P: Debug, E: Debug, T: EngineTypes + Debug, V: Debug, C: Debug, F> std::fmt::Debug
for EngineApiTreeHandler<N, P, E, T, V, C, F>
where
N: NodePrimitives,
F: StateRootTaskFactory<P>,
{
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("EngineApiTreeHandler")
Expand All @@ -634,7 +635,7 @@ where
}
}

impl<N, P, E, T, V, C> EngineApiTreeHandler<N, P, E, T, V, C>
impl<N, P, E, T, V, C, F> EngineApiTreeHandler<N, P, E, T, V, C, F>
where
N: NodePrimitives,
P: DatabaseProviderFactory
Expand All @@ -651,6 +652,7 @@ where
C: ConfigureEvm<Header = N::BlockHeader, Transaction = N::SignedTx> + 'static,
T: EngineTypes,
V: EngineValidator<T, Block = N::Block>,
F: StateRootTaskFactory<P>,
{
/// Creates a new [`EngineApiTreeHandler`].
#[expect(clippy::too_many_arguments)]
Expand All @@ -668,19 +670,10 @@ where
config: TreeConfig,
engine_kind: EngineApiKind,
evm_config: C,
state_root_task_factory: F,
) -> Self {
let (incoming_tx, incoming) = std::sync::mpsc::channel();

let num_threads = root::rayon_thread_pool_size();

let thread_pool = Arc::new(
rayon::ThreadPoolBuilder::new()
.num_threads(num_threads)
.thread_name(|i| format!("srt-worker-{}", i))
.build()
.expect("Failed to create proof worker thread pool"),
);

Self {
provider,
executor_provider,
Expand All @@ -701,7 +694,7 @@ where
invalid_block_hook: Box::new(NoopInvalidBlockHook),
engine_kind,
most_recent_cache: None,
thread_pool,
state_root_task_factory,
}
}

Expand All @@ -728,6 +721,7 @@ where
invalid_block_hook: Box<dyn InvalidBlockHook<N>>,
kind: EngineApiKind,
evm_config: C,
state_root_task_factory: F,
) -> (Sender<FromEngine<EngineApiRequest<T, N>, N::Block>>, UnboundedReceiver<EngineApiEvent<N>>)
{
let best_block_number = provider.best_block_number().unwrap_or(0);
Expand Down Expand Up @@ -760,6 +754,7 @@ where
config,
kind,
evm_config,
state_root_task_factory,
);
task.set_invalid_block_hook(invalid_block_hook);
let incoming = task.incoming_tx.clone();
Expand Down Expand Up @@ -2459,9 +2454,9 @@ where
.set(config_elapsed.as_secs_f64());

let state_root_task =
StateRootTask::new(state_root_config.clone(), self.thread_pool.clone());
self.state_root_task_factory.create_task(state_root_config.clone());
let state_root_sender = state_root_task.state_root_message_sender();
let state_hook = Box::new(state_root_task.state_hook()) as Box<dyn OnStateHook>;
let state_hook = state_root_task.state_hook() as Box<dyn OnStateHook>;
(
Some(state_root_task.spawn()),
Some(state_root_config),
Expand Down Expand Up @@ -2741,7 +2736,7 @@ where
let evm_config = self.evm_config.clone();

// spawn task executing the individual tx
self.thread_pool.spawn(move || {
self.state_root_task_factory.thread_pool().spawn(move || {
let thread_start = Instant::now();
let in_progress = task_finished.read().unwrap();

Expand Down Expand Up @@ -2891,11 +2886,11 @@ where
))
}

/// Waits for the result on the input [`StateRootHandle`], and handles it, falling back to
/// Waits for the result on the input state root handle, and handles it, falling back to
/// the hash builder-based state root calculation if it fails.
fn handle_state_root_result(
&self,
state_root_handle: StateRootHandle,
state_root_handle: <F::Runner as StateRootTaskRunner>::ResultHandle,
state_root_task_config: StateRootConfig<P>,
sealed_block: &SealedBlock<N::Block>,
hashed_state: &HashedPostState,
Expand Down Expand Up @@ -3232,7 +3227,10 @@ pub enum InsertPayloadOk {

#[cfg(test)]
mod tests {
use super::*;
use super::{
root::{StateRootComputeHandle, StateRootTaskRunner},
*,
};
use crate::persistence::PersistenceAction;
use alloy_consensus::Header;
use alloy_primitives::Bytes;
Expand All @@ -3248,11 +3246,12 @@ mod tests {
use reth_ethereum_consensus::EthBeaconConsensus;
use reth_ethereum_engine_primitives::{EthEngineTypes, EthereumEngineValidator};
use reth_ethereum_primitives::{Block, EthPrimitives};
use reth_evm::test_utils::MockExecutorProvider;
use reth_evm::{system_calls::StateChangeSource, test_utils::MockExecutorProvider};
use reth_evm_ethereum::EthEvmConfig;
use reth_primitives_traits::Block as _;
use reth_provider::test_utils::MockEthProvider;
use reth_trie::{updates::TrieUpdates, HashedPostState};
use revm_state::EvmState;
use std::{
str::FromStr,
sync::mpsc::{channel, Sender},
Expand Down Expand Up @@ -3313,6 +3312,78 @@ mod tests {
}
}

struct MockStateRootHandle {
root: B256,
}

impl StateRootComputeHandle for MockStateRootHandle {
fn wait_for_result(self) -> Result<StateRootComputeOutcome, ParallelStateRootError> {
Ok(StateRootComputeOutcome {
state_root: (self.root, TrieUpdates::default()),
total_time: Duration::from_secs(0),
time_from_last_update: Duration::from_secs(0),
})
}
}

struct MockStateRootTask {
root: B256,
}

impl MockStateRootTask {
fn new(root: B256) -> Self {
Self { root }
}
}

impl StateRootTaskRunner for MockStateRootTask {
type ResultHandle = MockStateRootHandle;

fn spawn(self) -> Self::ResultHandle {
MockStateRootHandle { root: self.root }
}

fn state_hook(&self) -> Box<dyn OnStateHook> {
Box::new(move |_: StateChangeSource, _: &EvmState| {})
}

fn state_root_message_sender(&self) -> Sender<StateRootMessage> {
let (tx, _rx) = channel();
tx
}
}

struct MockStateRootTaskFactory {
roots: Vec<B256>,
thread_pool: Arc<rayon::ThreadPool>,
}

impl MockStateRootTaskFactory {
fn new() -> Self {
let num_threads = root::rayon_thread_pool_size();
let thread_pool =
Arc::new(rayon::ThreadPoolBuilder::new().num_threads(num_threads).build().unwrap());

Self { roots: Vec::new(), thread_pool }
}

fn add_state_root(&mut self, root: B256) {
self.roots.push(root);
}
}

impl<Provider> StateRootTaskFactory<Provider> for MockStateRootTaskFactory {
type Runner = MockStateRootTask;

fn create_task(&mut self, _config: StateRootConfig<Provider>) -> Self::Runner {
MockStateRootTask::new(self.roots.pop().unwrap())
}

fn thread_pool(&self) -> Arc<rayon::ThreadPool> {
self.thread_pool.clone()
}
}

struct TestHarness {
tree: EngineApiTreeHandler<
EthPrimitives,
Expand All @@ -3321,6 +3392,7 @@ mod tests {
EthEngineTypes,
EthereumEngineValidator,
EthEvmConfig,
MockStateRootTaskFactory,
>,
to_tree_tx: Sender<FromEngine<EngineApiRequest<EthEngineTypes, EthPrimitives>, Block>>,
from_tree_rx: UnboundedReceiver<EngineApiEvent>,
Expand Down Expand Up @@ -3353,6 +3425,11 @@ mod tests {
let consensus = Arc::new(EthBeaconConsensus::new(chain_spec.clone()));

let provider = MockEthProvider::default();
let mut block_builder =
TestBlockBuilder::default().with_chain_spec((*chain_spec).clone());
let genesis_block = block_builder.get_executed_block_with_number(0, B256::random());
provider.add_block(B256::default(), genesis_block.into_sealed_block().clone_block());

let executor_provider = MockExecutorProvider::default();

let payload_validator = EthereumEngineValidator::new(chain_spec.clone());
Expand All @@ -3367,7 +3444,9 @@ mod tests {
let (to_payload_service, _payload_command_rx) = unbounded_channel();
let payload_builder = PayloadBuilderHandle::new(to_payload_service);

let evm_config = EthEvmConfig::new(chain_spec.clone());
let evm_config = EthEvmConfig::new(chain_spec);

let state_root_task_factory = MockStateRootTaskFactory::new();

let tree = EngineApiTreeHandler::new(
provider.clone(),
Expand All @@ -3380,16 +3459,13 @@ mod tests {
persistence_handle,
PersistenceState::default(),
payload_builder,
// TODO: fix tests for state root task https://github.com/paradigmxyz/reth/issues/14376
// always assume enough parallelism for tests
TreeConfig::default()
.with_legacy_state_root(true)
.with_has_enough_parallelism(true),
TreeConfig::default().with_has_enough_parallelism(true),
EngineApiKind::Ethereum,
evm_config,
state_root_task_factory,
);

let block_builder = TestBlockBuilder::default().with_chain_spec((*chain_spec).clone());
Self {
to_tree_tx: tree.incoming_tx.clone(),
tree,
Expand Down Expand Up @@ -3463,7 +3539,7 @@ mod tests {
) -> Result<InsertPayloadOk, InsertBlockError<Block>> {
let execution_outcome = self.block_builder.get_execution_outcome(block.clone());
self.extend_execution_outcome([execution_outcome]);
self.tree.provider.add_state_root(block.state_root);
self.tree.state_root_task_factory.add_state_root(block.state_root);
self.tree.insert_block(block)
}

Expand Down Expand Up @@ -3666,7 +3742,7 @@ mod tests {
} else {
block.state_root
};
self.tree.provider.add_state_root(state_root);
self.tree.state_root_task_factory.add_state_root(state_root);
execution_outcomes.push(execution_outcome);
}
self.extend_execution_outcome(execution_outcomes);
Expand Down
Loading

0 comments on commit b6a3f1c

Please sign in to comment.