From 7abb8c4ba207a7b8ac9c864174792baa209b3019 Mon Sep 17 00:00:00 2001 From: DaniPopes <57450786+DaniPopes@users.noreply.github.com> Date: Thu, 30 Nov 2023 21:36:50 +0100 Subject: [PATCH] fix compact --- crates/primitives/src/trie/nibbles.rs | 71 ++++++++++++++----- crates/storage/codecs/src/lib.rs | 59 ++++++++------- crates/trie/src/trie_cursor/account_cursor.rs | 1 - 3 files changed, 87 insertions(+), 44 deletions(-) diff --git a/crates/primitives/src/trie/nibbles.rs b/crates/primitives/src/trie/nibbles.rs index a0bd77328787..b60102d58554 100644 --- a/crates/primitives/src/trie/nibbles.rs +++ b/crates/primitives/src/trie/nibbles.rs @@ -60,14 +60,12 @@ impl Compact for StoredNibblesSubKey { /// Structure representing a sequence of nibbles. /// -/// A nibble is a 4-bit value, and this structure is used to store the nibble sequence -/// representing the keys in a Merkle Patricia Trie (MPT). -/// Using nibbles simplifies trie operations and enables consistent key representation in the -/// MPT. +/// A nibble is a 4-bit value, and this structure is used to store the nibble sequence representing +/// the keys in a Merkle Patricia Trie (MPT). +/// Using nibbles simplifies trie operations and enables consistent key representation in the MPT. /// -/// The internal representation is a shared heap-allocated vector ([`Bytes`]) that stores one -/// nibble per byte. This means that each byte has its upper 4 bits set to zero and the lower 4 -/// bits representing the nibble value. +/// The internal representation is a [`SmallVec`] that stores one nibble per byte. This means that +/// each byte has its upper 4 bits set to zero and the lower 4 bits representing the nibble value. #[derive( Clone, Default, @@ -117,7 +115,8 @@ impl Compact for Nibbles { where B: bytes::BufMut + AsMut<[u8]>, { - self.to_vec().to_compact(buf) + buf.put_slice(self.as_slice()); + self.len() } fn from_compact(mut buf: &[u8], len: usize) -> (Self, &[u8]) { @@ -295,7 +294,7 @@ impl Nibbles { unsafe fn pack_heap(&self) -> SmallVec<[u8; 32]> { // Collect into a vec directly to avoid the smallvec overhead since we know this is going on // the heap. - let packed_len = self.len() / 2; + let packed_len = (self.len() + 1) / 2; let mut vec = Vec::with_capacity(packed_len); self.pack_to::(vec.as_mut_ptr()); vec.set_len(packed_len); @@ -313,8 +312,9 @@ impl Nibbles { ptr.add(i).write(self.get_byte_unchecked(i * 2)); } if IS_ODD { + debug_assert!(self.len() % 2 != 0); let i = self.len() / 2; - ptr.add(i).write(self.last().unwrap() << 4); + ptr.add(i).write(self.last().unwrap_unchecked() << 4); } } @@ -325,6 +325,8 @@ impl Nibbles { /// `i..i + 1` must be in range. #[inline] unsafe fn get_byte_unchecked(&self, i: usize) -> u8 { + debug_assert!(i % 2 == 0, "index {i} is not a multiple of 2"); + debug_assert!(i + 1 < self.len(), "index {i}..{} out of bounds of {}", i + 1, self.len()); let hi = *self.get_unchecked(i); let lo = *self.get_unchecked(i + 1); (hi << 4) | lo @@ -432,7 +434,7 @@ impl Nibbles { #[inline] #[track_caller] pub fn at(&self, i: usize) -> usize { - self.0[i] as usize + self[i] as usize } /// Returns the last nibble of the current nibble sequence. @@ -483,10 +485,10 @@ impl Nibbles { /// Join two nibbles together. #[inline] pub fn join(&self, b: &Self) -> Self { - let mut hex_data = Vec::with_capacity(self.len() + b.len()); - hex_data.extend_from_slice(self); - hex_data.extend_from_slice(b); - Self::new_unchecked(hex_data) + let mut nibbles = SmallVec::with_capacity(self.len() + b.len()); + nibbles.extend_from_slice(self); + nibbles.extend_from_slice(b); + Self(nibbles) } /// Pushes a nibble to the end of the current nibbles. @@ -530,24 +532,58 @@ mod tests { #[test] fn pack_nibbles() { - for (input, expected) in [ + let tests = [ (&[][..], &[][..]), (&[0xa], &[0xa0]), + (&[0xa, 0x0], &[0xa0]), (&[0xa, 0xb], &[0xab]), (&[0xa, 0xb, 0x2], &[0xab, 0x20]), (&[0xa, 0xb, 0x2, 0x0], &[0xab, 0x20]), (&[0xa, 0xb, 0x2, 0x7], &[0xab, 0x27]), - ] { + ]; + for (input, expected) in tests { + assert!(input.iter().all(|&x| x <= 0xf)); let nibbles = Nibbles::new_unchecked(input); let encoded = nibbles.pack(); assert_eq!(&encoded[..], expected); } } + #[test] + fn slice() { + const RAW: &[u8] = &hex!("05010406040a040203030f010805020b050c04070003070e0909070f010b0a0805020301070c0a0902040b0f000f0006040a04050f020b090701000a0a040b"); + + #[track_caller] + fn test_slice(range: impl RangeBounds, expected: &[u8]) { + let nibbles = Nibbles::new_unchecked(RAW); + let sliced = nibbles.slice(range); + assert_eq!(sliced, Nibbles::new_unchecked(expected)); + assert_eq!(sliced.as_slice(), expected); + } + + test_slice(0..0, &[]); + test_slice(0..1, &[0x05]); + test_slice(1..1, &[]); + test_slice(1..=1, &[0x01]); + test_slice(0..=1, &[0x05, 0x01]); + test_slice(0..2, &[0x05, 0x01]); + + test_slice(..0, &[]); + test_slice(..1, &[0x05]); + test_slice(..=1, &[0x05, 0x01]); + test_slice(..2, &[0x05, 0x01]); + + test_slice(.., RAW); + test_slice(..RAW.len(), RAW); + test_slice(0.., RAW); + test_slice(0..RAW.len(), RAW); + } + proptest! { #[test] fn pack_unpack_roundtrip(input in any::>()) { let nibbles = Nibbles::unpack(&input); + prop_assert!(nibbles.iter().all(|&nibble| nibble <= 0xf)); let packed = nibbles.pack(); prop_assert_eq!(&packed[..], input); } @@ -556,6 +592,7 @@ mod tests { fn encode_path_first_byte(input in any::>()) { prop_assume!(!input.is_empty()); let input = Nibbles::unpack(input); + prop_assert!(input.iter().all(|&nibble| nibble <= 0xf)); let input_is_odd = input.len() % 2 == 1; let compact_leaf = input.encode_path_leaf(true); diff --git a/crates/storage/codecs/src/lib.rs b/crates/storage/codecs/src/lib.rs index aeb0f2c5bece..7a216138dc3f 100644 --- a/crates/storage/codecs/src/lib.rs +++ b/crates/storage/codecs/src/lib.rs @@ -63,7 +63,9 @@ macro_rules! impl_uint_compact { ($($name:tt),+) => { $( impl Compact for $name { - fn to_compact(self, buf: &mut B) -> usize where B: bytes::BufMut + AsMut<[u8]> { + fn to_compact(self, buf: &mut B) -> usize + where B: bytes::BufMut + AsMut<[u8]> + { let leading = self.leading_zeros() as usize / 8; buf.put_slice(&self.to_be_bytes()[leading..]); std::mem::size_of::<$name>() - leading @@ -255,44 +257,49 @@ impl Compact for Bytes { } } -/// Implements the [`Compact`] trait for fixed size hash types like [`B256`]. +impl Compact for [u8; N] { + fn to_compact(self, buf: &mut B) -> usize + where + B: bytes::BufMut + AsMut<[u8]>, + { + buf.put_slice(&self); + N + } + + fn from_compact(mut buf: &[u8], len: usize) -> (Self, &[u8]) { + if len == 0 { + return ([0; N], buf) + } + + let v = buf[..N].try_into().unwrap(); + buf.advance(N); + (v, buf) + } +} + +/// Implements the [`Compact`] trait for fixed size byte array types like [`B256`]. #[macro_export] -macro_rules! impl_hash_compact { +macro_rules! impl_compact_for_bytes { ($($name:tt),+) => { $( impl Compact for $name { - fn to_compact(self, buf: &mut B) -> usize where B: bytes::BufMut + AsMut<[u8]> { - buf.put_slice(self.as_slice()); - std::mem::size_of::<$name>() - } - - fn from_compact(mut buf: &[u8], len: usize) -> (Self,&[u8]) { - if len == 0 { - return ($name::default(), buf) - } - - let v = $name::from_slice( - buf.get(..std::mem::size_of::<$name>()).expect("size not matching"), - ); - buf.advance(std::mem::size_of::<$name>()); - (v, buf) - } - - fn specialized_to_compact(self, buf: &mut B) -> usize + fn to_compact(self, buf: &mut B) -> usize where - B: bytes::BufMut + AsMut<[u8]> { - self.to_compact(buf) + B: bytes::BufMut + AsMut<[u8]> + { + self.0.to_compact(buf) } - fn specialized_from_compact(buf: &[u8], len: usize) -> (Self, &[u8]) { - Self::from_compact(buf, len) + fn from_compact(buf: &[u8], len: usize) -> (Self, &[u8]) { + let (v, buf) = <[u8; std::mem::size_of::<$name>()]>::from_compact(buf, len); + (Self::from(v), buf) } } )+ }; } -impl_hash_compact!(Address, B256, B512, Bloom); +impl_compact_for_bytes!(Address, B256, B512, Bloom); impl Compact for bool { /// `bool` vars go directly to the `StructFlags` and are not written to the buffer. diff --git a/crates/trie/src/trie_cursor/account_cursor.rs b/crates/trie/src/trie_cursor/account_cursor.rs index b98be8bd2ba3..94e7be590146 100644 --- a/crates/trie/src/trie_cursor/account_cursor.rs +++ b/crates/trie/src/trie_cursor/account_cursor.rs @@ -41,7 +41,6 @@ where #[cfg(test)] mod tests { - use super::*; use reth_db::{ cursor::{DbCursorRO, DbCursorRW},