Skip to content

Commit

Permalink
feat(trie): collect branch node hash masks when calculating a proof (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
shekhirin authored Dec 4, 2024
1 parent 337272c commit 27dab59
Show file tree
Hide file tree
Showing 3 changed files with 108 additions and 10 deletions.
31 changes: 27 additions & 4 deletions crates/trie/common/src/proofs.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ use alloy_rlp::{encode_fixed_size, Decodable, EMPTY_STRING_CODE};
use alloy_trie::{
nodes::TrieNode,
proof::{verify_proof, ProofNodes, ProofVerificationError},
EMPTY_ROOT_HASH,
TrieMask, EMPTY_ROOT_HASH,
};
use itertools::Itertools;
use reth_primitives_traits::Account;
Expand All @@ -23,6 +23,8 @@ use reth_primitives_traits::Account;
pub struct MultiProof {
/// State trie multiproof for requested accounts.
pub account_subtree: ProofNodes,
/// The hash masks of the branch nodes in the account proof.
pub branch_node_hash_masks: HashMap<Nibbles, TrieMask>,
/// Storage trie multiproofs.
pub storages: HashMap<B256, StorageMultiProof>,
}
Expand Down Expand Up @@ -108,11 +110,15 @@ impl MultiProof {
pub fn extend(&mut self, other: Self) {
self.account_subtree.extend_from(other.account_subtree);

self.branch_node_hash_masks.extend(other.branch_node_hash_masks);

for (hashed_address, storage) in other.storages {
match self.storages.entry(hashed_address) {
hash_map::Entry::Occupied(mut entry) => {
debug_assert_eq!(entry.get().root, storage.root);
entry.get_mut().subtree.extend_from(storage.subtree);
let entry = entry.get_mut();
entry.subtree.extend_from(storage.subtree);
entry.branch_node_hash_masks.extend(storage.branch_node_hash_masks);
}
hash_map::Entry::Vacant(entry) => {
entry.insert(storage);
Expand All @@ -129,6 +135,8 @@ pub struct StorageMultiProof {
pub root: B256,
/// Storage multiproof for requested slots.
pub subtree: ProofNodes,
/// The hash masks of the branch nodes in the storage proof.
pub branch_node_hash_masks: HashMap<Nibbles, TrieMask>,
}

impl StorageMultiProof {
Expand All @@ -140,6 +148,7 @@ impl StorageMultiProof {
Nibbles::default(),
Bytes::from([EMPTY_STRING_CODE]),
)]),
branch_node_hash_masks: HashMap::default(),
}
}

Expand Down Expand Up @@ -380,14 +389,28 @@ mod tests {
Nibbles::from_nibbles(vec![0]),
alloy_rlp::encode_fixed_size(&U256::from(42)).to_vec().into(),
);
proof1.storages.insert(addr, StorageMultiProof { root, subtree: subtree1 });
proof1.storages.insert(
addr,
StorageMultiProof {
root,
subtree: subtree1,
branch_node_hash_masks: HashMap::default(),
},
);

let mut subtree2 = ProofNodes::default();
subtree2.insert(
Nibbles::from_nibbles(vec![1]),
alloy_rlp::encode_fixed_size(&U256::from(43)).to_vec().into(),
);
proof2.storages.insert(addr, StorageMultiProof { root, subtree: subtree2 });
proof2.storages.insert(
addr,
StorageMultiProof {
root,
subtree: subtree2,
branch_node_hash_masks: HashMap::default(),
},
);

proof1.extend(proof2);

Expand Down
28 changes: 26 additions & 2 deletions crates/trie/parallel/src/proof.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,8 @@ pub struct ParallelProof<Factory> {
view: ConsistentDbView<Factory>,
/// Trie input.
input: Arc<TrieInput>,
/// Flag indicating whether to include branch node hash masks in the proof.
collect_branch_node_hash_masks: bool,
/// Parallel state root metrics.
#[cfg(feature = "metrics")]
metrics: ParallelStateRootMetrics,
Expand All @@ -46,10 +48,17 @@ impl<Factory> ParallelProof<Factory> {
Self {
view,
input,
collect_branch_node_hash_masks: false,
#[cfg(feature = "metrics")]
metrics: ParallelStateRootMetrics::default(),
}
}

/// Set the flag indicating whether to include branch node hash masks in the proof.
pub const fn with_branch_node_hash_masks(mut self, branch_node_hash_masks: bool) -> Self {
self.collect_branch_node_hash_masks = branch_node_hash_masks;
self
}
}

impl<Factory> ParallelProof<Factory>
Expand Down Expand Up @@ -125,6 +134,7 @@ where
hashed_address,
)
.with_prefix_set_mut(PrefixSetMut::from(prefix_set.iter().cloned()))
.with_branch_node_hash_masks(self.collect_branch_node_hash_masks)
.storage_multiproof(target_slots)
.map_err(|e| {
ParallelStateRootError::StorageRoot(StorageRootError::Database(
Expand Down Expand Up @@ -158,7 +168,9 @@ where

// Create a hash builder to rebuild the root node since it is not available in the database.
let retainer: ProofRetainer = targets.keys().map(Nibbles::unpack).collect();
let mut hash_builder = HashBuilder::default().with_proof_retainer(retainer);
let mut hash_builder = HashBuilder::default()
.with_proof_retainer(retainer)
.with_updates(self.collect_branch_node_hash_masks);

let mut storages = HashMap::default();
let mut account_rlp = Vec::with_capacity(TRIE_ACCOUNT_RLP_MAX_SIZE);
Expand Down Expand Up @@ -222,7 +234,19 @@ where
#[cfg(feature = "metrics")]
self.metrics.record_state_trie(tracker.finish());

Ok(MultiProof { account_subtree: hash_builder.take_proof_nodes(), storages })
let account_subtree = hash_builder.take_proof_nodes();
let branch_node_hash_masks = if self.collect_branch_node_hash_masks {
hash_builder
.updated_branch_nodes
.unwrap_or_default()
.into_iter()
.map(|(path, node)| (path, node.hash_mask))
.collect()
} else {
HashMap::default()
};

Ok(MultiProof { account_subtree, branch_node_hash_masks, storages })
}
}

Expand Down
59 changes: 55 additions & 4 deletions crates/trie/trie/src/proof/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,8 @@ pub struct Proof<T, H> {
hashed_cursor_factory: H,
/// A set of prefix sets that have changes.
prefix_sets: TriePrefixSetsMut,
/// Flag indicating whether to include branch node hash masks in the proof.
collect_branch_node_hash_masks: bool,
}

impl<T, H> Proof<T, H> {
Expand All @@ -42,6 +44,7 @@ impl<T, H> Proof<T, H> {
trie_cursor_factory: t,
hashed_cursor_factory: h,
prefix_sets: TriePrefixSetsMut::default(),
collect_branch_node_hash_masks: false,
}
}

Expand All @@ -51,6 +54,7 @@ impl<T, H> Proof<T, H> {
trie_cursor_factory,
hashed_cursor_factory: self.hashed_cursor_factory,
prefix_sets: self.prefix_sets,
collect_branch_node_hash_masks: self.collect_branch_node_hash_masks,
}
}

Expand All @@ -60,6 +64,7 @@ impl<T, H> Proof<T, H> {
trie_cursor_factory: self.trie_cursor_factory,
hashed_cursor_factory,
prefix_sets: self.prefix_sets,
collect_branch_node_hash_masks: self.collect_branch_node_hash_masks,
}
}

Expand All @@ -68,6 +73,12 @@ impl<T, H> Proof<T, H> {
self.prefix_sets = prefix_sets;
self
}

/// Set the flag indicating whether to include branch node hash masks in the proof.
pub const fn with_branch_node_hash_masks(mut self, branch_node_hash_masks: bool) -> Self {
self.collect_branch_node_hash_masks = branch_node_hash_masks;
self
}
}

impl<T, H> Proof<T, H>
Expand Down Expand Up @@ -104,7 +115,9 @@ where

// Create a hash builder to rebuild the root node since it is not available in the database.
let retainer = targets.keys().map(Nibbles::unpack).collect();
let mut hash_builder = HashBuilder::default().with_proof_retainer(retainer);
let mut hash_builder = HashBuilder::default()
.with_proof_retainer(retainer)
.with_updates(self.collect_branch_node_hash_masks);

// Initialize all storage multiproofs as empty.
// Storage multiproofs for non empty tries will be overwritten if necessary.
Expand All @@ -131,6 +144,7 @@ where
hashed_address,
)
.with_prefix_set_mut(storage_prefix_set)
.with_branch_node_hash_masks(self.collect_branch_node_hash_masks)
.storage_multiproof(proof_targets.unwrap_or_default())?;

// Encode account
Expand All @@ -149,7 +163,19 @@ where
}
}
let _ = hash_builder.root();
Ok(MultiProof { account_subtree: hash_builder.take_proof_nodes(), storages })
let account_subtree = hash_builder.take_proof_nodes();
let branch_node_hash_masks = if self.collect_branch_node_hash_masks {
hash_builder
.updated_branch_nodes
.unwrap_or_default()
.into_iter()
.map(|(path, node)| (path, node.hash_mask))
.collect()
} else {
HashMap::default()
};

Ok(MultiProof { account_subtree, branch_node_hash_masks, storages })
}
}

Expand All @@ -164,6 +190,8 @@ pub struct StorageProof<T, H> {
hashed_address: B256,
/// The set of storage slot prefixes that have changed.
prefix_set: PrefixSetMut,
/// Flag indicating whether to include branch node hash masks in the proof.
collect_branch_node_hash_masks: bool,
}

impl<T, H> StorageProof<T, H> {
Expand All @@ -179,6 +207,7 @@ impl<T, H> StorageProof<T, H> {
hashed_cursor_factory: h,
hashed_address,
prefix_set: PrefixSetMut::default(),
collect_branch_node_hash_masks: false,
}
}

Expand All @@ -189,6 +218,7 @@ impl<T, H> StorageProof<T, H> {
hashed_cursor_factory: self.hashed_cursor_factory,
hashed_address: self.hashed_address,
prefix_set: self.prefix_set,
collect_branch_node_hash_masks: self.collect_branch_node_hash_masks,
}
}

Expand All @@ -199,6 +229,7 @@ impl<T, H> StorageProof<T, H> {
hashed_cursor_factory,
hashed_address: self.hashed_address,
prefix_set: self.prefix_set,
collect_branch_node_hash_masks: self.collect_branch_node_hash_masks,
}
}

Expand All @@ -207,6 +238,12 @@ impl<T, H> StorageProof<T, H> {
self.prefix_set = prefix_set;
self
}

/// Set the flag indicating whether to include branch node hash masks in the proof.
pub const fn with_branch_node_hash_masks(mut self, branch_node_hash_masks: bool) -> Self {
self.collect_branch_node_hash_masks = branch_node_hash_masks;
self
}
}

impl<T, H> StorageProof<T, H>
Expand Down Expand Up @@ -243,7 +280,9 @@ where
let walker = TrieWalker::new(trie_cursor, self.prefix_set.freeze());

let retainer = ProofRetainer::from_iter(target_nibbles);
let mut hash_builder = HashBuilder::default().with_proof_retainer(retainer);
let mut hash_builder = HashBuilder::default()
.with_proof_retainer(retainer)
.with_updates(self.collect_branch_node_hash_masks);
let mut storage_node_iter = TrieNodeIter::new(walker, hashed_storage_cursor);
while let Some(node) = storage_node_iter.try_next()? {
match node {
Expand All @@ -260,6 +299,18 @@ where
}

let root = hash_builder.root();
Ok(StorageMultiProof { root, subtree: hash_builder.take_proof_nodes() })
let subtree = hash_builder.take_proof_nodes();
let branch_node_hash_masks = if self.collect_branch_node_hash_masks {
hash_builder
.updated_branch_nodes
.unwrap_or_default()
.into_iter()
.map(|(path, node)| (path, node.hash_mask))
.collect()
} else {
HashMap::default()
};

Ok(StorageMultiProof { root, subtree, branch_node_hash_masks })
}
}

0 comments on commit 27dab59

Please sign in to comment.