diff --git a/crates/stages/api/src/pipeline/mod.rs b/crates/stages/api/src/pipeline/mod.rs
index 928c43fb6..ae93ace1e 100644
--- a/crates/stages/api/src/pipeline/mod.rs
+++ b/crates/stages/api/src/pipeline/mod.rs
@@ -1,15 +1,16 @@
 mod ctrl;
 mod event;
 pub use crate::pipeline::ctrl::ControlFlow;
-use crate::{PipelineTarget, StageCheckpoint, StageId};
+use crate::{LatestStateProviderFactory, PipelineTarget, StageCheckpoint, StageId};
 use alloy_primitives::{BlockNumber, B256};
 pub use event::*;
 use futures_util::Future;
 use reth_primitives_traits::constants::BEACON_CONSENSUS_REORG_UNWIND_DEPTH;
 use reth_provider::{
     providers::ProviderNodeTypes, writer::UnifiedStorageWriter, DatabaseProviderFactory,
-    FinalizedBlockReader, FinalizedBlockWriter, ProviderFactory, StageCheckpointReader,
-    StageCheckpointWriter, StaticFileProviderFactory,
+    FinalizedBlockReader, FinalizedBlockWriter, ProviderFactory, ProviderResult,
+    StageCheckpointReader, StageCheckpointWriter, StateProviderBox, StateProviderOptions,
+    StaticFileProviderFactory,
 };
 use reth_prune::PrunerBuilder;
 use reth_static_file::StaticFileProducer;
@@ -107,6 +108,12 @@ impl<N: ProviderNodeTypes> Pipeline<N> {
     }
 }
 
+impl<N: ProviderNodeTypes> LatestStateProviderFactory for ProviderFactory<N> {
+    fn latest(&self, opts: StateProviderOptions) -> ProviderResult<StateProviderBox> {
+        self.latest(opts)
+    }
+}
+
 impl<N: ProviderNodeTypes> Pipeline<N> {
     /// Registers progress metrics for each registered stage
     pub fn register_metrics(&mut self) -> Result<(), PipelineError> {
@@ -435,7 +442,7 @@ impl<N: ProviderNodeTypes> Pipeline<N> {
                 target,
             });
 
-            match stage.execute(&provider_rw, exec_input) {
+            match stage.execute_v2(&provider_rw, &self.provider_factory, exec_input) {
                 Ok(out @ ExecOutput { checkpoint, done }) => {
                     made_progress |=
                         checkpoint.block_number != prev_checkpoint.unwrap_or_default().block_number;
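Note on the `LatestStateProviderFactory` impl added above: the `self.latest(opts)` call in its body is not self-recursion — inherent methods win over trait methods during Rust method resolution, so it dispatches to the inherent `ProviderFactory::latest` that this diff updates in `crates/storage/provider/src/providers/database/mod.rs`. A minimal, self-contained sketch of that resolution rule, using toy types unrelated to reth:

```rust
trait Speak {
    fn hello(&self) -> &'static str;
}

struct Greeter;

impl Greeter {
    /// Inherent method with the same name as the trait method.
    fn hello(&self) -> &'static str {
        "inherent"
    }
}

impl Speak for Greeter {
    fn hello(&self) -> &'static str {
        // Resolves to the inherent `Greeter::hello`, not back into `Speak::hello`.
        self.hello()
    }
}

fn main() {
    assert_eq!(Greeter.hello(), "inherent"); // the call site also picks the inherent method
    assert_eq!(Speak::hello(&Greeter), "inherent"); // the trait method delegates to it
}
```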
diff --git a/crates/stages/api/src/stage.rs b/crates/stages/api/src/stage.rs
index 1e201aee6..1217668dc 100644
--- a/crates/stages/api/src/stage.rs
+++ b/crates/stages/api/src/stage.rs
@@ -1,6 +1,8 @@
 use crate::{error::StageError, StageCheckpoint, StageId};
 use alloy_primitives::{BlockNumber, TxNumber};
-use reth_provider::{BlockReader, ProviderError};
+use reth_provider::{
+    BlockReader, ProviderError, ProviderResult, StateProviderBox, StateProviderOptions,
+};
 use std::{
     cmp::{max, min},
     future::{poll_fn, Future},
@@ -178,6 +180,12 @@ pub struct UnwindOutput {
     pub checkpoint: StageCheckpoint,
 }
 
+/// A factory for creating a state provider for the latest block.
+pub trait LatestStateProviderFactory {
+    /// Creates a state provider for the latest block.
+    fn latest(&self, opts: StateProviderOptions) -> ProviderResult<StateProviderBox>;
+}
+
 /// A stage is a segmented part of the syncing process of the node.
 ///
 /// Each stage takes care of a well-defined task, such as downloading headers or executing
@@ -233,6 +241,16 @@ pub trait Stage<Provider>: Send + Sync {
     /// upon invoking this method.
     fn execute(&mut self, provider: &Provider, input: ExecInput) -> Result<ExecOutput, StageError>;
 
+    /// Execute the stage with access to a [`LatestStateProviderFactory`]. The default implementation ignores the factory and delegates to [`Stage::execute`].
+    fn execute_v2(
+        &mut self,
+        provider: &Provider,
+        _factory: &dyn LatestStateProviderFactory,
+        input: ExecInput,
+    ) -> Result<ExecOutput, StageError> {
+        self.execute(provider, input)
+    }
+
     /// Post execution commit hook.
     ///
     /// This is called after the stage has been executed and the data has been committed by the
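Because `execute_v2` has a default body that simply forwards to `execute`, existing `Stage` implementations compile and behave exactly as before; only stages that want the latest-state factory need to override it. A sketch of such an override follows — the stage name, the reader count of 4, and the stub bodies are illustrative, and `poll_execute_ready` plus the commit hooks are assumed to keep their trait defaults:

```rust
use reth_provider::StateProviderOptions;
use reth_stages_api::{
    ExecInput, ExecOutput, LatestStateProviderFactory, Stage, StageCheckpoint, StageError,
    StageId, UnwindInput, UnwindOutput,
};
use std::num::NonZero;

/// Illustrative stage that opts into the factory-aware entry point.
struct ParallelReadStage;

impl<Provider> Stage<Provider> for ParallelReadStage {
    fn id(&self) -> StageId {
        StageId::Other("parallel-read-stage")
    }

    fn execute(
        &mut self,
        _provider: &Provider,
        input: ExecInput,
    ) -> Result<ExecOutput, StageError> {
        // Fallback path for callers that still go through `execute`.
        Ok(ExecOutput::done(input.checkpoint()))
    }

    fn execute_v2(
        &mut self,
        provider: &Provider,
        factory: &dyn LatestStateProviderFactory,
        input: ExecInput,
    ) -> Result<ExecOutput, StageError> {
        // Ask for a latest-state view backed by four read transactions (arbitrary count).
        let _state =
            factory.latest(StateProviderOptions { parallel: NonZero::new(4).unwrap() })?;
        // ...read state through `_state` here, then fall through to the regular path...
        self.execute(provider, input)
    }

    fn unwind(
        &mut self,
        _provider: &Provider,
        input: UnwindInput,
    ) -> Result<UnwindOutput, StageError> {
        Ok(UnwindOutput { checkpoint: StageCheckpoint::new(input.unwind_to) })
    }

    // `poll_execute_ready`, `post_execute_commit` and `post_unwind_commit` are assumed
    // to keep their default implementations from the trait.
}
```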
diff --git a/crates/stages/stages/src/stages/execution.rs b/crates/stages/stages/src/stages/execution.rs
index c7ea70f69..495474fda 100644
--- a/crates/stages/stages/src/stages/execution.rs
+++ b/crates/stages/stages/src/stages/execution.rs
@@ -15,18 +15,19 @@ use reth_provider::{
     providers::{StaticFileProvider, StaticFileProviderRWRefMut, StaticFileWriter},
     writer::UnifiedStorageWriter,
     BlockReader, DBProvider, HeaderProvider, LatestStateProviderRef, OriginalValuesKnown,
-    ProviderError, StateChangeWriter, StateWriter, StaticFileProviderFactory, StatsReader,
-    TransactionVariant,
+    ProviderError, StateChangeWriter, StateProvider, StateProviderOptions, StateWriter,
+    StaticFileProviderFactory, StatsReader, TransactionVariant,
 };
 use reth_prune_types::PruneModes;
 use reth_revm::database::StateProviderDatabase;
 use reth_stages_api::{
     BlockErrorKind, CheckpointBlockRange, EntitiesCheckpoint, ExecInput, ExecOutput,
-    ExecutionCheckpoint, ExecutionStageThresholds, Stage, StageCheckpoint, StageError, StageId,
-    UnwindInput, UnwindOutput,
+    ExecutionCheckpoint, ExecutionStageThresholds, LatestStateProviderFactory, Stage,
+    StageCheckpoint, StageError, StageId, UnwindInput, UnwindOutput,
 };
 use std::{
     cmp::Ordering,
+    num::NonZero,
     ops::RangeInclusive,
     sync::Arc,
     task::{ready, Context, Poll},
@@ -194,6 +195,133 @@ where
     /// Execute the stage
     fn execute(&mut self, provider: &Provider, input: ExecInput) -> Result<ExecOutput, StageError> {
+        self.execute_inner(provider, None, input)
+    }
+
+    fn execute_v2(
+        &mut self,
+        provider: &Provider,
+        factory: &dyn LatestStateProviderFactory,
+        input: ExecInput,
+    ) -> Result<ExecOutput, StageError> {
+        self.execute_inner(provider, Some(factory), input)
+    }
+
+    fn post_execute_commit(&mut self) -> Result<(), StageError> {
+        let Some(chain) = self.post_execute_commit_input.take() else { return Ok(()) };
+
+        // NOTE: We can ignore the error here, since an error means that the channel is closed,
+        // which means the manager has died, which then in turn means the node is shutting down.
+        let _ = self
+            .exex_manager_handle
+            .send(ExExNotification::ChainCommitted { new: Arc::new(chain) });
+
+        Ok(())
+    }
+
+    /// Unwind the stage.
+    fn unwind(
+        &mut self,
+        provider: &Provider,
+        input: UnwindInput,
+    ) -> Result<UnwindOutput, StageError> {
+        let (range, unwind_to, _) =
+            input.unwind_block_range_with_threshold(self.thresholds.max_blocks.unwrap_or(u64::MAX));
+        if range.is_empty() {
+            return Ok(UnwindOutput {
+                checkpoint: input.checkpoint.with_block_number(input.unwind_to),
+            })
+        }
+
+        // Unwind account and storage changesets, as well as receipts.
+        //
+        // This also updates `PlainStorageState` and `PlainAccountState`.
+        let bundle_state_with_receipts = provider.take_state(range.clone())?;
+
+        // Prepare the input for post unwind commit hook, where an `ExExNotification` will be sent.
+        if self.exex_manager_handle.has_exexs() {
+            // Get the blocks for the unwound range.
+            let blocks = provider.sealed_block_with_senders_range(range.clone())?;
+            let previous_input = self.post_unwind_commit_input.replace(Chain::new(
+                blocks,
+                bundle_state_with_receipts,
+                None,
+            ));
+
+            debug_assert!(
+                previous_input.is_none(),
+                "Previous post unwind commit input wasn't processed"
+            );
+            if let Some(previous_input) = previous_input {
+                tracing::debug!(target: "sync::stages::execution", ?previous_input, "Previous post unwind commit input wasn't processed");
+            }
+        }
+
+        let static_file_provider = provider.static_file_provider();
+
+        // Unwind all receipts for transactions in the block range
+        if self.prune_modes.receipts.is_none() && self.prune_modes.receipts_log_filter.is_empty() {
+            // We only use static files for Receipts, if there is no receipt pruning of any kind.
+
+            // prepare_static_file_producer does a consistency check that will unwind static files
+            // if the expected highest receipt in the files is higher than the database.
+            // Which is essentially what happens here when we unwind this stage.
+            let _static_file_producer =
+                prepare_static_file_producer(provider, &static_file_provider, *range.start())?;
+        } else {
+            // If there is any kind of receipt pruning/filtering we use the database, since static
+            // files do not support filters.
+            //
+            // If we hit this case, the receipts have already been unwound by the call to
+            // `take_state`.
+        }
+
+        // Update the checkpoint.
+        let mut stage_checkpoint = input.checkpoint.execution_stage_checkpoint();
+        if let Some(stage_checkpoint) = stage_checkpoint.as_mut() {
+            for block_number in range {
+                stage_checkpoint.progress.processed -= provider
+                    .block_by_number(block_number)?
+                    .ok_or_else(|| ProviderError::HeaderNotFound(block_number.into()))?
+                    .gas_used;
+            }
+        }
+        let checkpoint = if let Some(stage_checkpoint) = stage_checkpoint {
+            StageCheckpoint::new(unwind_to).with_execution_stage_checkpoint(stage_checkpoint)
+        } else {
+            StageCheckpoint::new(unwind_to)
+        };
+
+        Ok(UnwindOutput { checkpoint })
+    }
+
+    fn post_unwind_commit(&mut self) -> Result<(), StageError> {
+        let Some(chain) = self.post_unwind_commit_input.take() else { return Ok(()) };
+
+        // NOTE: We can ignore the error here, since an error means that the channel is closed,
+        // which means the manager has died, which then in turn means the node is shutting down.
+        let _ =
+            self.exex_manager_handle.send(ExExNotification::ChainReverted { old: Arc::new(chain) });
+
+        Ok(())
+    }
+}
+
+impl<E> ExecutionStage<E>
+where
+    E: BlockExecutorProvider,
+{
+    fn execute_inner<Provider>(
+        &mut self,
+        provider: &Provider,
+        factory: Option<&dyn LatestStateProviderFactory>,
+        input: ExecInput,
+    ) -> Result<ExecOutput, StageError>
+    where
+        Provider:
+            DBProvider + BlockReader + StaticFileProviderFactory + StatsReader + StateChangeWriter,
+        for<'a> UnifiedStorageWriter<'a, Provider, StaticFileProviderRWRefMut<'a>>: StateWriter,
+    {
         if input.target_reached() {
             return Ok(ExecOutput::done(input.checkpoint()))
         }
@@ -218,16 +346,29 @@ where
             None
         };
 
-        let db = StateProviderDatabase(LatestStateProviderRef::new(
-            provider.tx_ref(),
-            provider.static_file_provider(),
-        ));
-        let mut executor =
-            if let Some(parallel_provider) = self.executor_provider.try_into_parallel_provider() {
-                EitherBatchExecutor::Parallel(parallel_provider.batch_executor(Arc::new(db)))
+        let mut executor = if let Some(parallel_provider) =
+            self.executor_provider.try_into_parallel_provider()
+        {
+            let db: Arc<dyn StateProvider> = if let Some(factory) = factory {
+                Arc::new(
+                    factory.latest(StateProviderOptions { parallel: NonZero::new(8).unwrap() })?,
+                )
             } else {
-                EitherBatchExecutor::Sequential(self.executor_provider.batch_executor(db))
+                Arc::new(LatestStateProviderRef::new(
+                    provider.tx_ref(),
+                    provider.static_file_provider(),
+                ))
             };
+            EitherBatchExecutor::Parallel(
+                parallel_provider.batch_executor(StateProviderDatabase(db)),
+            )
+        } else {
+            let db = StateProviderDatabase(LatestStateProviderRef::new(
+                provider.tx_ref(),
+                provider.static_file_provider(),
+            ));
+            EitherBatchExecutor::Sequential(self.executor_provider.batch_executor(db))
+        };
 
         executor.set_tip(max_block);
         executor.set_prune_modes(prune_modes);
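The factory is consulted only on the parallel-executor path above: when `execute_v2` passed one in, the stage asks for a latest-state provider backed by eight read transactions; otherwise both branches keep using the transaction-bound `LatestStateProviderRef` exactly as before. A minimal sketch of that request in isolation — the constant name is invented here, the diff hard-codes `8`:

```rust
use reth_provider::{ProviderResult, StateProviderBox, StateProviderOptions};
use reth_stages_api::LatestStateProviderFactory;
use std::num::NonZero;

/// Number of read transactions backing the latest-state view (hard-coded as `8` above).
const PARALLEL_READERS: usize = 8;

/// Requests the latest-state provider that `execute_inner` then wraps in `Arc` and
/// `StateProviderDatabase` before handing it to the parallel batch executor.
fn parallel_latest_state(
    factory: &dyn LatestStateProviderFactory,
) -> ProviderResult<StateProviderBox> {
    factory.latest(StateProviderOptions { parallel: NonZero::new(PARALLEL_READERS).unwrap() })
}
```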
@@ -383,105 +524,6 @@ where
             done,
         })
     }
-
-    fn post_execute_commit(&mut self) -> Result<(), StageError> {
-        let Some(chain) = self.post_execute_commit_input.take() else { return Ok(()) };
-
-        // NOTE: We can ignore the error here, since an error means that the channel is closed,
-        // which means the manager has died, which then in turn means the node is shutting down.
-        let _ = self
-            .exex_manager_handle
-            .send(ExExNotification::ChainCommitted { new: Arc::new(chain) });
-
-        Ok(())
-    }
-
-    /// Unwind the stage.
-    fn unwind(
-        &mut self,
-        provider: &Provider,
-        input: UnwindInput,
-    ) -> Result<UnwindOutput, StageError> {
-        let (range, unwind_to, _) =
-            input.unwind_block_range_with_threshold(self.thresholds.max_blocks.unwrap_or(u64::MAX));
-        if range.is_empty() {
-            return Ok(UnwindOutput {
-                checkpoint: input.checkpoint.with_block_number(input.unwind_to),
-            })
-        }
-
-        // Unwind account and storage changesets, as well as receipts.
-        //
-        // This also updates `PlainStorageState` and `PlainAccountState`.
-        let bundle_state_with_receipts = provider.take_state(range.clone())?;
-
-        // Prepare the input for post unwind commit hook, where an `ExExNotification` will be sent.
-        if self.exex_manager_handle.has_exexs() {
-            // Get the blocks for the unwound range.
-            let blocks = provider.sealed_block_with_senders_range(range.clone())?;
-            let previous_input = self.post_unwind_commit_input.replace(Chain::new(
-                blocks,
-                bundle_state_with_receipts,
-                None,
-            ));
-
-            debug_assert!(
-                previous_input.is_none(),
-                "Previous post unwind commit input wasn't processed"
-            );
-            if let Some(previous_input) = previous_input {
-                tracing::debug!(target: "sync::stages::execution", ?previous_input, "Previous post unwind commit input wasn't processed");
-            }
-        }
-
-        let static_file_provider = provider.static_file_provider();
-
-        // Unwind all receipts for transactions in the block range
-        if self.prune_modes.receipts.is_none() && self.prune_modes.receipts_log_filter.is_empty() {
-            // We only use static files for Receipts, if there is no receipt pruning of any kind.
-
-            // prepare_static_file_producer does a consistency check that will unwind static files
-            // if the expected highest receipt in the files is higher than the database.
-            // Which is essentially what happens here when we unwind this stage.
-            let _static_file_producer =
-                prepare_static_file_producer(provider, &static_file_provider, *range.start())?;
-        } else {
-            // If there is any kind of receipt pruning/filtering we use the database, since static
-            // files do not support filters.
-            //
-            // If we hit this case, the receipts have already been unwound by the call to
-            // `take_state`.
-        }
-
-        // Update the checkpoint.
-        let mut stage_checkpoint = input.checkpoint.execution_stage_checkpoint();
-        if let Some(stage_checkpoint) = stage_checkpoint.as_mut() {
-            for block_number in range {
-                stage_checkpoint.progress.processed -= provider
-                    .block_by_number(block_number)?
-                    .ok_or_else(|| ProviderError::HeaderNotFound(block_number.into()))?
-                    .gas_used;
-            }
-        }
-        let checkpoint = if let Some(stage_checkpoint) = stage_checkpoint {
-            StageCheckpoint::new(unwind_to).with_execution_stage_checkpoint(stage_checkpoint)
-        } else {
-            StageCheckpoint::new(unwind_to)
-        };
-
-        Ok(UnwindOutput { checkpoint })
-    }
-
-    fn post_unwind_commit(&mut self) -> Result<(), StageError> {
-        let Some(chain) = self.post_unwind_commit_input.take() else { return Ok(()) };
-
-        // NOTE: We can ignore the error here, since an error means that the channel is closed,
-        // which means the manager has died, which then in turn means the node is shutting down.
-        let _ =
-            self.exex_manager_handle.send(ExExNotification::ChainReverted { old: Arc::new(chain) });
-
-        Ok(())
-    }
 }
 
 fn execution_checkpoint(
diff --git a/crates/storage/provider/src/providers/blockchain_provider.rs b/crates/storage/provider/src/providers/blockchain_provider.rs
index d38f582b4..e1947082b 100644
--- a/crates/storage/provider/src/providers/blockchain_provider.rs
+++ b/crates/storage/provider/src/providers/blockchain_provider.rs
@@ -1080,15 +1080,15 @@ impl<N: ProviderNodeTypes> ChainSpecProvider for BlockchainProvider2<N> {
 
 impl<N: ProviderNodeTypes> StateProviderFactory for BlockchainProvider2<N> {
     /// Storage provider for latest block
-    fn latest(&self) -> ProviderResult<StateProviderBox> {
+    fn latest_with_opts(&self, opts: StateProviderOptions) -> ProviderResult<StateProviderBox> {
         trace!(target: "providers::blockchain", "Getting latest block state provider");
         // use latest state provider if the head state exists
         if let Some(state) = self.canonical_in_memory_state.head_state() {
             trace!(target: "providers::blockchain", "Using head state for latest state provider");
-            Ok(self.block_state_provider(state, Default::default())?.boxed())
+            Ok(self.block_state_provider(state, opts)?.boxed())
         } else {
             trace!(target: "providers::blockchain", "Using database state for latest state provider");
-            self.database.latest()
+            self.database.latest(opts)
         }
     }
diff --git a/crates/storage/provider/src/providers/database/mod.rs b/crates/storage/provider/src/providers/database/mod.rs
index c722b5130..0d373e4c2 100644
--- a/crates/storage/provider/src/providers/database/mod.rs
+++ b/crates/storage/provider/src/providers/database/mod.rs
@@ -158,9 +158,13 @@ impl<N: ProviderNodeTypes> ProviderFactory<N> {
     /// State provider for latest block
     #[track_caller]
-    pub fn latest(&self) -> ProviderResult<StateProviderBox> {
+    pub fn latest(&self, opts: StateProviderOptions) -> ProviderResult<StateProviderBox> {
         trace!(target: "providers::db", "Returning latest state provider");
-        Ok(Box::new(LatestStateProvider::new(self.db.tx()?, self.static_file_provider())))
+        if opts.parallel.get() > 1 {
+            Ok(Box::new(ParallelStateProvider::try_new_latest(self, opts.parallel.get())?))
+        } else {
+            Ok(Box::new(LatestStateProvider::new(self.db.tx()?, self.static_file_provider())))
+        }
     }
 
     /// Storage provider for state at that given block
@@ -665,7 +669,7 @@ mod tests {
     #[test]
     fn common_history_provider() {
         let factory = create_test_provider_factory();
-        let _ = factory.latest();
+        let _ = factory.latest(Default::default());
     }
 
     #[test]
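Callers of `ProviderFactory::latest` now pick the provider flavour through `StateProviderOptions`: the default keeps the single-transaction `LatestStateProvider`, while anything above one reader returns a `ParallelStateProvider`. A sketch of both call shapes — the function name and the count of 4 are illustrative, and the parallel flavour spawns tokio worker tasks, so it has to run inside a multi-threaded tokio runtime:

```rust
use reth_provider::{
    providers::ProviderNodeTypes, ProviderFactory, ProviderResult, StateProviderOptions,
};
use std::num::NonZero;

fn open_latest_states<N: ProviderNodeTypes>(factory: &ProviderFactory<N>) -> ProviderResult<()> {
    // Default options preserve the old behaviour (one read transaction), matching the
    // updated `common_history_provider` test above.
    let _sequential = factory.latest(Default::default())?;

    // More than one reader makes the factory return a `ParallelStateProvider` that
    // fans reads out over four worker transactions.
    let _parallel =
        factory.latest(StateProviderOptions { parallel: NonZero::new(4).unwrap() })?;

    Ok(())
}
```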
diff --git a/crates/storage/provider/src/providers/database/parallel_provider.rs b/crates/storage/provider/src/providers/database/parallel_provider.rs
index dc8425702..2a837c762 100644
--- a/crates/storage/provider/src/providers/database/parallel_provider.rs
+++ b/crates/storage/provider/src/providers/database/parallel_provider.rs
@@ -4,6 +4,7 @@ use std::{
 };
 
 use alloy_primitives::{Address, BlockNumber, Bytes, B256};
+use reth_db::Database;
 use reth_primitives::{Account, Bytecode, StorageKey, StorageValue};
 use reth_storage_api::{
     AccountReader, BlockHashReader, StateProofProvider, StateProvider, StateRootProvider,
@@ -14,7 +15,7 @@ use reth_trie::{
     updates::TrieUpdates, AccountProof, HashedPostState, HashedStorage, MultiProof, TrieInput,
 };
 
-use crate::providers::ProviderNodeTypes;
+use crate::{providers::ProviderNodeTypes, LatestStateProvider, StaticFileProviderFactory};
 
 use super::ProviderFactory;
 use flume as mpmc;
@@ -24,34 +25,40 @@ enum StateProviderTask {
     Storage(Address, StorageKey, oneshot::Sender<ProviderResult<Option<StorageValue>>>),
     BytecodeByHash(B256, oneshot::Sender<ProviderResult<Option<Bytecode>>>),
     BasicAccount(Address, oneshot::Sender<ProviderResult<Option<Account>>>),
+    BlockHash(u64, oneshot::Sender<ProviderResult<Option<B256>>>),
 }
 
 impl StateProviderTask {
     fn process(self, state_provider: &dyn StateProvider) {
         match self {
-            StateProviderTask::Storage(address, key, tx) => {
+            Self::Storage(address, key, tx) => {
                 let result = tokio::task::block_in_place(|| state_provider.storage(address, key));
                 let _ = tx.send(result);
             }
-            StateProviderTask::BytecodeByHash(code_hash, tx) => {
+            Self::BytecodeByHash(code_hash, tx) => {
                 let result =
                     tokio::task::block_in_place(|| state_provider.bytecode_by_hash(code_hash));
                 let _ = tx.send(result);
             }
-            StateProviderTask::BasicAccount(address, tx) => {
+            Self::BasicAccount(address, tx) => {
                 let result = tokio::task::block_in_place(|| state_provider.basic_account(address));
                 let _ = tx.send(result);
            }
+            Self::BlockHash(block_number, tx) => {
+                let result =
+                    tokio::task::block_in_place(|| state_provider.block_hash(block_number));
+                let _ = tx.send(result);
+            }
         }
     }
 }
 
-pub struct ParallelStateProvider {
+pub(super) struct ParallelStateProvider {
     task_tx: mpmc::Sender<StateProviderTask>,
 }
 
 impl ParallelStateProvider {
-    pub fn try_new<N>(
+    pub(super) fn try_new<N>(
         db: &ProviderFactory<N>,
         block_number: u64,
         parallel: usize,
@@ -66,6 +73,33 @@ impl ParallelStateProvider {
         for _ in 0..parallel {
             let state_provider = db.provider()?.try_into_history_at_block(block_number)?;
             let task_rx = task_rx.clone();
+            // TODO: use individual tokio runtime
             tokio::spawn(async move {
                 while let Ok(task) = task_rx.recv_async().await {
                     task.process(state_provider.as_ref());
                 }
             });
         }
 
         Ok(Self { task_tx })
     }
+
+    pub(super) fn try_new_latest<N>(
+        db: &ProviderFactory<N>,
+        parallel: usize,
+    ) -> ProviderResult<Self>
+    where
+        N: ProviderNodeTypes,
+    {
+        assert!(parallel > 1, "parallel must be greater than 1");
+
+        let (task_tx, task_rx) = mpmc::unbounded::<StateProviderTask>();
+
+        for _ in 0..parallel {
+            let state_provider =
+                Box::new(LatestStateProvider::new(db.db_ref().tx()?, db.static_file_provider()));
+            let task_rx = task_rx.clone();
+            // TODO: use individual tokio runtime
             tokio::spawn(async move {
                 while let Ok(task) = task_rx.recv_async().await {
                     task.process(state_provider.as_ref());
@@ -81,20 +115,22 @@ impl StateProvider for ParallelStateProvider {
     fn storage(&self, address: Address, key: StorageKey) -> ProviderResult<Option<StorageValue>> {
         let (tx, rx) = oneshot::channel();
         let _ = self.task_tx.send(StateProviderTask::Storage(address, key, tx));
-        rx.blocking_recv().unwrap()
+        tokio::task::block_in_place(|| rx.blocking_recv().unwrap())
     }
 
     fn bytecode_by_hash(&self, code_hash: B256) -> ProviderResult<Option<Bytecode>> {
         let (tx, rx) = oneshot::channel();
         let _ = self.task_tx.send(StateProviderTask::BytecodeByHash(code_hash, tx));
-        rx.blocking_recv().unwrap()
+        tokio::task::block_in_place(|| rx.blocking_recv().unwrap())
     }
 }
 
 #[allow(unused)]
 impl BlockHashReader for ParallelStateProvider {
     fn block_hash(&self, block_number: u64) -> ProviderResult<Option<B256>> {
-        todo!()
+        let (tx, rx) = oneshot::channel();
+        let _ = self.task_tx.send(StateProviderTask::BlockHash(block_number, tx));
+        tokio::task::block_in_place(|| rx.blocking_recv().unwrap())
     }
 
     fn canonical_hashes_range(
@@ -110,7 +146,7 @@ impl AccountReader for ParallelStateProvider {
     fn basic_account(&self, address: Address) -> ProviderResult<Option<Account>> {
         let (tx, rx) = oneshot::channel();
         let _ = self.task_tx.send(StateProviderTask::BasicAccount(address, tx));
-        rx.blocking_recv().unwrap()
+        tokio::task::block_in_place(|| rx.blocking_recv().unwrap())
     }
 }
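`ParallelStateProvider` is a small worker pool: every read becomes a `StateProviderTask` pushed onto a shared flume channel (aliased as `mpmc` above), one of the spawned workers runs it against its own provider instance, and the result comes back on a per-request oneshot channel; the caller then blocks via `tokio::task::block_in_place`, which requires the multi-threaded tokio runtime. A stripped-down sketch of the same round-trip pattern with toy types:

```rust
use tokio::sync::oneshot;

/// Toy request type standing in for `StateProviderTask`.
enum Task {
    Double(u64, oneshot::Sender<u64>),
}

/// Spawns `workers` tasks sharing one MPMC receiver and returns the send side.
fn spawn_workers(workers: usize) -> flume::Sender<Task> {
    let (task_tx, task_rx) = flume::unbounded::<Task>();
    for _ in 0..workers {
        let task_rx = task_rx.clone();
        tokio::spawn(async move {
            while let Ok(Task::Double(n, reply)) = task_rx.recv_async().await {
                // The real workers call blocking `StateProvider` methods here, hence the
                // `block_in_place` wrapping in the diff; this toy computation is cheap.
                let _ = reply.send(n * 2);
            }
        });
    }
    task_tx
}

/// Mirrors the provider methods above: send a request, then block for the reply
/// without starving the runtime (panics on a current-thread tokio runtime).
fn request_double(task_tx: &flume::Sender<Task>, n: u64) -> u64 {
    let (tx, rx) = oneshot::channel();
    let _ = task_tx.send(Task::Double(n, tx));
    tokio::task::block_in_place(|| rx.blocking_recv().unwrap())
}
```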
diff --git a/crates/storage/provider/src/providers/mod.rs b/crates/storage/provider/src/providers/mod.rs
index 81f35ea89..81659ae47 100644
--- a/crates/storage/provider/src/providers/mod.rs
+++ b/crates/storage/provider/src/providers/mod.rs
@@ -592,9 +592,9 @@ impl<N: ProviderNodeTypes> ChainSpecProvider for BlockchainProvider<N> {
 
 impl<N: ProviderNodeTypes> StateProviderFactory for BlockchainProvider<N> {
     /// Storage provider for latest block
-    fn latest(&self) -> ProviderResult<StateProviderBox> {
+    fn latest_with_opts(&self, opts: StateProviderOptions) -> ProviderResult<StateProviderBox> {
         trace!(target: "providers::blockchain", "Getting latest block state provider");
-        self.database.latest()
+        self.database.latest(opts)
     }
 
     fn history_by_block_number(
diff --git a/crates/storage/storage-api/src/state.rs b/crates/storage/storage-api/src/state.rs
index 12df38557..aee37f44a 100644
--- a/crates/storage/storage-api/src/state.rs
+++ b/crates/storage/storage-api/src/state.rs
@@ -129,7 +129,12 @@ impl Default for StateProviderOptions {
 #[auto_impl(&, Arc, Box)]
 pub trait StateProviderFactory: BlockIdReader + Send + Sync {
     /// Storage provider for latest block.
-    fn latest(&self) -> ProviderResult<StateProviderBox>;
+    fn latest(&self) -> ProviderResult<StateProviderBox> {
+        self.latest_with_opts(StateProviderOptions::default())
+    }
+
+    /// Variant of [`Self::latest`] that accepts explicit [`StateProviderOptions`].
+    fn latest_with_opts(&self, opts: StateProviderOptions) -> ProviderResult<StateProviderBox>;
 
     /// Returns a [`StateProvider`] indexed by the given [`BlockId`].
     ///
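From the call site the `StateProviderFactory` change is invisible: `latest()` still exists, but it is now a provided method that forwards default options to `latest_with_opts`, so only implementors have to adapt. A small caller-side sketch — the function is illustrative, and the types are assumed to be re-exported by `reth_provider`:

```rust
use reth_provider::{ProviderResult, StateProviderFactory, StateProviderOptions};
use std::num::NonZero;

fn read_latest<F: StateProviderFactory>(factory: &F) -> ProviderResult<()> {
    // Provided method: equivalent to `latest_with_opts(StateProviderOptions::default())`.
    let _default = factory.latest()?;

    // Explicit options, e.g. to request a parallel state provider where supported.
    let _parallel =
        factory.latest_with_opts(StateProviderOptions { parallel: NonZero::new(8).unwrap() })?;

    Ok(())
}
```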