chore: use DatabaseProviderRW instead of TX on stages (#9451)
joshieDo authored Jul 11, 2024
1 parent 11c5e31 commit 112b233
Showing 5 changed files with 30 additions and 28 deletions.
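
The common thread across all five files: stage helpers that previously received the raw transaction type &<DB as Database>::TXMut now receive the &DatabaseProviderRW<DB> wrapper and borrow the transaction through provider.tx_ref() only where a cursor or lookup is actually needed. A minimal, self-contained sketch of the shape of the refactor (ProviderRW, Tx, Cursor, and the two write_headers variants below are simplified stand-ins invented for illustration, not the real reth types):

// ProviderRW, Tx, and Cursor are simplified stand-ins invented for illustration;
// they are not the real reth DatabaseProviderRW, transaction, or cursor types.
struct Tx;
struct Cursor;

impl Tx {
    fn cursor_write(&self) -> Result<Cursor, String> {
        Ok(Cursor)
    }
}

struct ProviderRW {
    tx: Tx,
}

impl ProviderRW {
    // Plays the role of DatabaseProviderRW::tx_ref: borrow the inner transaction.
    fn tx_ref(&self) -> &Tx {
        &self.tx
    }
}

// Before: the helper was coupled to the raw transaction type.
fn write_headers_old(tx: &Tx) -> Result<(), String> {
    let _cursor = tx.cursor_write()?;
    Ok(())
}

// After: the helper takes the provider and borrows the transaction only where a cursor
// is actually needed, leaving the signature free to use other provider facilities later.
fn write_headers_new(provider: &ProviderRW) -> Result<(), String> {
    let _cursor = provider.tx_ref().cursor_write()?;
    Ok(())
}

fn main() -> Result<(), String> {
    let provider = ProviderRW { tx: Tx };
    write_headers_old(provider.tx_ref())?;
    write_headers_new(&provider)?;
    Ok(())
}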
7 changes: 4 additions & 3 deletions crates/stages/stages/src/stages/headers.rs
@@ -90,7 +90,7 @@ where
     /// database table.
     fn write_headers<DB: Database>(
         &mut self,
-        tx: &<DB as Database>::TXMut,
+        provider: &DatabaseProviderRW<DB>,
         static_file_provider: StaticFileProvider,
     ) -> Result<BlockNumber, StageError> {
         let total_headers = self.header_collector.len();
@@ -143,7 +143,8 @@ where
 
         info!(target: "sync::stages::headers", total = total_headers, "Writing headers hash index");
 
-        let mut cursor_header_numbers = tx.cursor_write::<RawTable<tables::HeaderNumbers>>()?;
+        let mut cursor_header_numbers =
+            provider.tx_ref().cursor_write::<RawTable<tables::HeaderNumbers>>()?;
         let mut first_sync = false;
 
         // If we only have the genesis block hash, then we are at first sync, and we can remove it,
@@ -281,7 +282,7 @@ where
         // Write the headers and related tables to DB from ETL space
         let to_be_processed = self.hash_collector.len() as u64;
         let last_header_number =
-            self.write_headers::<DB>(provider.tx_ref(), provider.static_file_provider().clone())?;
+            self.write_headers(provider, provider.static_file_provider().clone())?;
 
         // Clear ETL collectors
         self.hash_collector.clear();
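
A side effect of the new write_headers signature is visible in the last hunk: the ::<DB> turbofish at the call site goes away. In the old form, DB appeared only inside the associated-type projection <DB as Database>::TXMut, which the compiler cannot use to infer DB (associated types are not injective), so the caller had to spell it out; &DatabaseProviderRW<DB> names DB directly and inference takes over.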
4 changes: 2 additions & 2 deletions crates/stages/stages/src/stages/index_account_history.rs
@@ -103,7 +103,7 @@ impl<DB: Database> Stage<DB> for IndexAccountHistoryStage {
         info!(target: "sync::stages::index_account_history::exec", ?first_sync, "Collecting indices");
         let collector =
             collect_history_indices::<_, tables::AccountChangeSets, tables::AccountsHistory, _>(
-                provider.tx_ref(),
+                provider,
                 range.clone(),
                 ShardedKey::new,
                 |(index, value)| (index, value.address),
@@ -112,7 +112,7 @@ impl<DB: Database> Stage<DB> for IndexAccountHistoryStage {
 
         info!(target: "sync::stages::index_account_history::exec", "Loading indices into database");
         load_history_indices::<_, tables::AccountsHistory, _>(
-            provider.tx_ref(),
+            provider,
             collector,
             first_sync,
             ShardedKey::new,
4 changes: 2 additions & 2 deletions crates/stages/stages/src/stages/index_storage_history.rs
@@ -106,7 +106,7 @@ impl<DB: Database> Stage<DB> for IndexStorageHistoryStage {
         info!(target: "sync::stages::index_storage_history::exec", ?first_sync, "Collecting indices");
         let collector =
             collect_history_indices::<_, tables::StorageChangeSets, tables::StoragesHistory, _>(
-                provider.tx_ref(),
+                provider,
                 BlockNumberAddress::range(range.clone()),
                 |AddressStorageKey((address, storage_key)), highest_block_number| {
                     StorageShardedKey::new(address, storage_key, highest_block_number)
@@ -117,7 +117,7 @@ impl<DB: Database> Stage<DB> for IndexStorageHistoryStage {
 
         info!(target: "sync::stages::index_storage_history::exec", "Loading indices into database");
         load_history_indices::<_, tables::StoragesHistory, _>(
-            provider.tx_ref(),
+            provider,
             collector,
             first_sync,
             |AddressStorageKey((address, storage_key)), highest_block_number| {
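
Both index stages change only at their call sites: collect_history_indices and load_history_indices (reworked in utils.rs below) now take the provider itself, so the stages pass provider where they previously passed provider.tx_ref().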
22 changes: 11 additions & 11 deletions crates/stages/stages/src/stages/sender_recovery.rs
@@ -84,10 +84,8 @@ impl<DB: Database> Stage<DB> for SenderRecoveryStage {
             })
         }
 
-        let tx = provider.tx_ref();
-
         // Acquire the cursor for inserting elements
-        let mut senders_cursor = tx.cursor_write::<tables::TransactionSenders>()?;
+        let mut senders_cursor = provider.tx_ref().cursor_write::<tables::TransactionSenders>()?;
 
         info!(target: "sync::stages::sender_recovery", ?tx_range, "Recovering senders");
 
@@ -98,7 +96,7 @@ impl<DB: Database> Stage<DB> for SenderRecoveryStage {
             .collect::<Vec<Range<u64>>>();
 
         for range in batch {
-            recover_range(range, provider, tx, &mut senders_cursor)?;
+            recover_range(range, provider, &mut senders_cursor)?;
         }
 
         Ok(ExecOutput {
@@ -130,14 +128,15 @@ impl<DB: Database> Stage<DB> for SenderRecoveryStage {
     }
 }
 
-fn recover_range<DB: Database>(
+fn recover_range<DB, CURSOR>(
     tx_range: Range<u64>,
     provider: &DatabaseProviderRW<DB>,
-    tx: &<DB as Database>::TXMut,
-    senders_cursor: &mut <<DB as Database>::TXMut as DbTxMut>::CursorMut<
-        tables::TransactionSenders,
-    >,
-) -> Result<(), StageError> {
+    senders_cursor: &mut CURSOR,
+) -> Result<(), StageError>
+where
+    DB: Database,
+    CURSOR: DbCursorRW<tables::TransactionSenders>,
+{
     debug!(target: "sync::stages::sender_recovery", ?tx_range, "Recovering senders batch");
 
     // Preallocate channels
@@ -193,7 +192,8 @@ fn recover_range<DB: Database>(
                 return match *error {
                     SenderRecoveryStageError::FailedRecovery(err) => {
                         // get the block number for the bad transaction
-                        let block_number = tx
+                        let block_number = provider
+                            .tx_ref()
                             .get::<tables::TransactionBlocks>(err.tx)?
                             .ok_or(ProviderError::BlockNumberForTransactionIndexNotFound)?;
 
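
The recover_range change is more than an argument swap: rather than spelling out the concrete cursor type <<DB as Database>::TXMut as DbTxMut>::CursorMut<tables::TransactionSenders>, the function is now generic over any CURSOR: DbCursorRW<tables::TransactionSenders>, so callers pass their cursor without the signature ever naming the transaction type. A self-contained sketch of that trick, where the DbCursorRW trait, TransactionSender value, and VecCursor are toy stand-ins invented here rather than the reth types:

// Toy stand-ins for reth's DbCursorRW trait and the TransactionSenders table value,
// used only to show the shape of taking `&mut CURSOR` behind a trait bound.
trait DbCursorRW<T> {
    fn append(&mut self, key: u64, value: T);
}

struct TransactionSender(String);

struct VecCursor {
    rows: Vec<(u64, TransactionSender)>,
}

impl DbCursorRW<TransactionSender> for VecCursor {
    fn append(&mut self, key: u64, value: TransactionSender) {
        self.rows.push((key, value));
    }
}

// Generic over the cursor: the signature no longer mentions a transaction type at all.
fn recover_range<CURSOR>(tx_range: std::ops::Range<u64>, senders_cursor: &mut CURSOR)
where
    CURSOR: DbCursorRW<TransactionSender>,
{
    for tx_id in tx_range {
        // In the real stage this would be the signer recovered from the transaction signature.
        senders_cursor.append(tx_id, TransactionSender(format!("sender-{tx_id}")));
    }
}

fn main() {
    let mut cursor = VecCursor { rows: Vec::new() };
    recover_range(0..3, &mut cursor);
    assert_eq!(cursor.rows.len(), 3);
}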
21 changes: 11 additions & 10 deletions crates/stages/stages/src/stages/utils.rs
@@ -1,6 +1,6 @@
 //! Utils for `stages`.
 use reth_config::config::EtlConfig;
-use reth_db::BlockNumberList;
+use reth_db::{BlockNumberList, Database};
 use reth_db_api::{
     cursor::{DbCursorRO, DbCursorRW},
     models::sharded_key::NUM_OF_INDICES_IN_SHARD,
@@ -10,6 +10,7 @@ use reth_db_api::{
 };
 use reth_etl::Collector;
 use reth_primitives::BlockNumber;
+use reth_provider::DatabaseProviderRW;
 use reth_stages_api::StageError;
 use std::{collections::HashMap, hash::Hash, ops::RangeBounds};
 use tracing::info;
@@ -34,20 +35,20 @@ const DEFAULT_CACHE_THRESHOLD: u64 = 100_000;
 ///
 /// As a result, the `Collector` will contain entries such as `(Address1.3, [1,2,3])` and
 /// `(Address1.300, [100,300])`. The entries may be stored across one or more files.
-pub(crate) fn collect_history_indices<TX, CS, H, P>(
-    tx: &TX,
+pub(crate) fn collect_history_indices<DB, CS, H, P>(
+    provider: &DatabaseProviderRW<DB>,
     range: impl RangeBounds<CS::Key>,
     sharded_key_factory: impl Fn(P, BlockNumber) -> H::Key,
     partial_key_factory: impl Fn((CS::Key, CS::Value)) -> (u64, P),
     etl_config: &EtlConfig,
 ) -> Result<Collector<H::Key, H::Value>, StageError>
 where
-    TX: DbTxMut + DbTx,
+    DB: Database,
     CS: Table,
     H: Table<Value = BlockNumberList>,
     P: Copy + Eq + Hash,
 {
-    let mut changeset_cursor = tx.cursor_read::<CS>()?;
+    let mut changeset_cursor = provider.tx_ref().cursor_read::<CS>()?;
 
     let mut collector = Collector::new(etl_config.file_size, etl_config.dir.clone());
     let mut cache: HashMap<P, Vec<u64>> = HashMap::new();
@@ -64,7 +65,7 @@ where
     };
 
     // observability
-    let total_changesets = tx.entries::<CS>()?;
+    let total_changesets = provider.tx_ref().entries::<CS>()?;
     let interval = (total_changesets / 1000).max(1);
 
     let mut flush_counter = 0;
@@ -101,20 +102,20 @@
 /// `Address.StorageKey`). It flushes indices to disk when reaching a shard's max length
 /// (`NUM_OF_INDICES_IN_SHARD`) or when the partial key changes, ensuring the last previous partial
 /// key shard is stored.
-pub(crate) fn load_history_indices<TX, H, P>(
-    tx: &TX,
+pub(crate) fn load_history_indices<DB, H, P>(
+    provider: &DatabaseProviderRW<DB>,
     mut collector: Collector<H::Key, H::Value>,
     append_only: bool,
     sharded_key_factory: impl Clone + Fn(P, u64) -> <H as Table>::Key,
     decode_key: impl Fn(Vec<u8>) -> Result<<H as Table>::Key, DatabaseError>,
     get_partial: impl Fn(<H as Table>::Key) -> P,
 ) -> Result<(), StageError>
 where
-    TX: DbTxMut + DbTx,
+    DB: Database,
     H: Table<Value = BlockNumberList>,
     P: Copy + Default + Eq,
 {
-    let mut write_cursor = tx.cursor_write::<H>()?;
+    let mut write_cursor = provider.tx_ref().cursor_write::<H>()?;
     let mut current_partial = P::default();
     let mut current_list = Vec::<u64>::new();
 
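
The two history-index helpers drop their TX: DbTxMut + DbTx generic in favour of DB: Database plus the &DatabaseProviderRW<DB> parameter, reaching the transaction through provider.tx_ref() for cursors and entry counts. The old bound was more general, but every caller already holds a DatabaseProviderRW, and taking the provider presumably leaves room for the helpers to use provider-level APIs later. The sharding scheme their doc comments describe can also be pictured with a small standalone sketch (shard_history and MAX_PER_SHARD are invented here for illustration; the real code streams changesets through an ETL Collector and database cursors instead):

use std::collections::BTreeMap;

// Simplified illustration of the sharding described in the doc comments above:
// block numbers are grouped per partial key (e.g. an address) and cut into shards of
// at most MAX_PER_SHARD entries, keyed by the highest block number in the shard.
// MAX_PER_SHARD stands in for NUM_OF_INDICES_IN_SHARD.
const MAX_PER_SHARD: usize = 3;

fn shard_history(changesets: &[(u64, &str)]) -> BTreeMap<(String, u64), Vec<u64>> {
    // Group block numbers per partial key.
    let mut per_key: BTreeMap<&str, Vec<u64>> = BTreeMap::new();
    for &(block, key) in changesets {
        per_key.entry(key).or_default().push(block);
    }

    // Cut each list into shards keyed by (partial key, highest block in the shard).
    let mut shards = BTreeMap::new();
    for (key, blocks) in per_key {
        for chunk in blocks.chunks(MAX_PER_SHARD) {
            let highest = *chunk.last().expect("chunks are never empty");
            shards.insert((key.to_string(), highest), chunk.to_vec());
        }
    }
    shards
}

fn main() {
    let changesets =
        [(1, "Address1"), (2, "Address1"), (3, "Address1"), (100, "Address1"), (300, "Address1")];
    let shards = shard_history(&changesets);
    // Mirrors the doc-comment example: (Address1.3, [1,2,3]) and (Address1.300, [100,300]).
    assert_eq!(shards[&("Address1".to_string(), 3)], vec![1, 2, 3]);
    assert_eq!(shards[&("Address1".to_string(), 300)], vec![100, 300]);
}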
