From 58cc3855a6e3fbad337f0c232643ab74a14a04e2 Mon Sep 17 00:00:00 2001
From: Tom Kaitchuck
Date: Sat, 3 Aug 2024 10:10:02 -0700
Subject: [PATCH] First pass implementation of VAES

Signed-off-by: Tom Kaitchuck
---
 src/aes_hash.rs   |  35 ++++++---------
 src/convert.rs    |   7 +++
 src/operations.rs | 111 +++++++++++++---------------------------------
 3 files changed, 51 insertions(+), 102 deletions(-)

diff --git a/src/aes_hash.rs b/src/aes_hash.rs
index daf3ae4..6658909 100644
--- a/src/aes_hash.rs
+++ b/src/aes_hash.rs
@@ -161,31 +161,24 @@ impl Hasher for AHasher {
         } else {
             if data.len() > 32 {
                 if data.len() > 64 {
-                    let tail = data.read_last_u128x4();
-                    let mut current: [u128; 4] = [self.key; 4];
-                    current[0] = aesenc(current[0], tail[0]);
-                    current[1] = aesdec(current[1], tail[1]);
-                    current[2] = aesenc(current[2], tail[2]);
-                    current[3] = aesdec(current[3], tail[3]);
-                    let mut sum: [u128; 2] = [self.key, !self.key];
-                    sum[0] = add_by_64s(sum[0].convert(), tail[0].convert()).convert();
-                    sum[1] = add_by_64s(sum[1].convert(), tail[1].convert()).convert();
-                    sum[0] = shuffle_and_add(sum[0], tail[2]);
-                    sum[1] = shuffle_and_add(sum[1], tail[3]);
+                    let tail: [U256; 2] = data.read_last_u128x4().convert();
+                    let mut current: [U256; 2] = [self.key; 4].convert();
+                    current[0] = vaes::aesenc_vec256(current[0], tail[0]);
+                    current[1] = vaes::aesenc_vec256(current[1], tail[1]);
+                    let mut sum: U256 = [self.key, !self.key];
+                    sum = vaes::add_by_64s_vec256(sum, tail[0]);
+                    sum = vaes::shuffle_and_add_vec256(sum, tail[1]);
                     while data.len() > 64 {
                         let (blocks, rest) = data.read_u128x4();
-                        current[0] = aesdec(current[0], blocks[0]);
-                        current[1] = aesdec(current[1], blocks[1]);
-                        current[2] = aesdec(current[2], blocks[2]);
-                        current[3] = aesdec(current[3], blocks[3]);
-                        sum[0] = shuffle_and_add(sum[0], blocks[0]);
-                        sum[1] = shuffle_and_add(sum[1], blocks[1]);
-                        sum[0] = shuffle_and_add(sum[0], blocks[2]);
-                        sum[1] = shuffle_and_add(sum[1], blocks[3]);
+                        let blocks: [U256; 2] = blocks.convert();
+                        current[0] = vaes::aesdec_vec256(current[0], blocks[0]);
+                        current[1] = vaes::aesdec_vec256(current[1], blocks[1]);
+                        sum = vaes::shuffle_and_add_vec256(sum, blocks[0]);
+                        sum = vaes::shuffle_and_add_vec256(sum, blocks[1]);
                         data = rest;
                     }
-                    self.hash_in_2(current[0], current[1]);
-                    self.hash_in_2(current[2], current[3]);
+                    self.hash_in_2(current[0][0], current[0][1]);
+                    self.hash_in_2(current[1][0], current[1][1]);
                     self.hash_in_2(sum[0], sum[1]);
                 } else {
                     //len 33-64
diff --git a/src/convert.rs b/src/convert.rs
index 712eae1..d5a91f3 100644
--- a/src/convert.rs
+++ b/src/convert.rs
@@ -2,6 +2,9 @@ pub(crate) trait Convert<To> {
     fn convert(self) -> To;
 }
 
+pub type U256 = [u128; 2];
+pub type U512 = [u128; 4];
+
 macro_rules! convert {
     ($a:ty, $b:ty) => {
         impl Convert<$b> for $a {
@@ -19,6 +22,10 @@ macro_rules! convert {
     };
 }
 
+convert!([U256; 2], [U512; 1]);
+convert!([u128; 4], [U512; 1]);
+convert!([u128; 4], [U256; 2]);
+convert!([u128; 2], [U256; 1]);
 convert!([u128; 4], [u64; 8]);
 convert!([u128; 4], [u32; 16]);
 convert!([u128; 4], [u16; 32]);
diff --git a/src/operations.rs b/src/operations.rs
index c0ddd4b..562b750 100644
--- a/src/operations.rs
+++ b/src/operations.rs
@@ -187,30 +187,33 @@ pub(crate) fn add_in_length(enc: &mut u128, len: u64) {
     all(target_arch = "aarch64", target_feature = "aes", not(miri)),
     all(feature = "nightly-arm-aes", target_arch = "arm", target_feature = "aes", not(miri)),
 ))]
-mod vaes {
+pub(crate) mod vaes {
     use super::*;
 
-    cfg_if::cfg_if! {
-        if #[cfg(all(
-            any(target_arch = "x86", target_arch = "x86_64"),
-            target_feature = "vaes", not(miri)
-        ))] {
-            pub type Vector256 = core::arch::x86_64::__m256i;
-            // impl FromBytes for Vector256 {
-            //     fn from_bytes(bytes: &[u8]) -> Self {
-            //         unsafe {
-            //             core::arch::x86_64::_mm256_loadu_si256(bytes.as_ptr().cast::())
-            //         }
-            //     }
-            // }
-
-        } else {
-            pub type Vector256 = [u128;2];
+    #[inline(always)]
+    pub(crate) fn aesdec_vec256(value: U256, xor: U256) -> U256 {
+        cfg_if::cfg_if! {
+            if #[cfg(all(
+                any(target_arch = "x86", target_arch = "x86_64"),
+                target_feature = "vaes",
+                not(miri)
+            ))] {
+                use core::arch::x86_64::*;
+                unsafe {
+                    transmute!(_mm256_aesdec_epi128(transmute!(value), transmute!(xor)))
+                }
+            }
+            else {
+                [
+                    aesdec(value[0], xor[0]),
+                    aesdec(value[1], xor[1]),
+                ]
+            }
         }
     }
 
     #[inline(always)]
-    pub(crate) fn aesdec_vec256(value: Vector256, xor: Vector256) -> Vector256 {
+    pub(crate) fn aesenc_vec256(value: U256, xor: U256) -> U256 {
         cfg_if::cfg_if! {
             if #[cfg(all(
                 any(target_arch = "x86", target_arch = "x86_64"),
@@ -219,20 +222,20 @@ mod vaes {
             ))] {
                 use core::arch::x86_64::*;
                 unsafe {
-                    _mm256_aesdec_epi128(value, xor)
+                    transmute!(_mm256_aesenc_epi128(transmute!(value), transmute!(xor)))
                 }
             }
             else {
                 [
-                    aesdec(value[0], xor[0]),
-                    aesdec(value[1], xor[1]),
+                    aesenc(value[0], xor[0]),
+                    aesenc(value[1], xor[1]),
                 ]
             }
         }
     }
 
     #[inline(always)]
-    pub(crate) fn add_by_64s_vec256(a: Vector256, b: Vector256) -> Vector256 {
+    pub(crate) fn add_by_64s_vec256(a: U256, b: U256) -> U256 {
         cfg_if::cfg_if! {
             if #[cfg(all(
                 any(target_arch = "x86", target_arch = "x86_64"),
@@ -240,7 +243,7 @@ mod vaes {
                 not(miri)
             ))] {
                 use core::arch::x86_64::*;
-                unsafe { _mm256_add_epi64(a, b) }
+                unsafe { transmute!(_mm256_add_epi64(transmute!(a), transmute!(b))) }
             }
             else {
                 [
@@ -252,7 +255,7 @@ mod vaes {
         }
     }
     #[inline(always)]
-    pub(crate) fn shuffle_vec256(value: Vector256) -> Vector256 {
+    pub(crate) fn shuffle_vec256(value: U256) -> U256 {
         cfg_if::cfg_if! {
             if #[cfg(all(
                 any(target_arch = "x86", target_arch = "x86_64"),
@@ -261,8 +264,8 @@ mod vaes {
             ))] {
                 unsafe {
                     use core::arch::x86_64::*;
-                    let mask = convert_u128_to_vec256(SHUFFLE_MASK, SHUFFLE_MASK);
-                    _mm256_shuffle_epi8(value, mask)
+                    let mask = transmute!([SHUFFLE_MASK, SHUFFLE_MASK]);
+                    transmute!(_mm256_shuffle_epi8(transmute!(value), mask))
                 }
             }
             else {
                 [
@@ -275,63 +278,9 @@ mod vaes {
         }
     }
 
-    pub(crate) fn shuffle_and_add_vec256(base: Vector256, to_add: Vector256) -> Vector256 {
+    pub(crate) fn shuffle_and_add_vec256(base: U256, to_add: U256) -> U256 {
         add_by_64s_vec256(shuffle_vec256(base), to_add)
     }
-
-    // We specialize this routine because sometimes the compiler is not able to
-    // optimize it properly.
-    pub(crate) fn read2_vec256(data: &[u8]) -> ([Vector256; 2], &[u8]) {
-        cfg_if::cfg_if! {
-            if #[cfg(all(
-                any(target_arch = "x86", target_arch = "x86_64"),
-                target_feature = "vaes",
-                not(miri)
-            ))] {
-                use core::arch::x86_64::*;
-                let (arr, rem) = data.split_at(64);
-                let arr = unsafe {
-                    [ _mm256_loadu_si256(arr.as_ptr().cast::<__m256i>()),
-                      _mm256_loadu_si256(arr.as_ptr().add(32).cast::<__m256i>()),
-                    ]
-                };
-                (arr, rem)
-            }
-            else {
-                let (arr, slice) = data.read_u128x4();
-                (transmute!(arr), slice)
-            }
-        }
-    }
-
-    // We specialize this routine because sometimes the compiler is not able to
-    // optimize it properly.
-    pub(crate) fn convert_u128_to_vec256(low: u128, high: u128) -> Vector256 {
-        transmute!([low, high])
-    }
-
-    // We specialize this routine because sometimes the compiler is not able to
-    // optimize it properly.
-    pub(crate) fn convert_vec256_to_u128(x: Vector256) -> [u128; 2] {
-        cfg_if::cfg_if! {
-            if #[cfg(all(
-                any(target_arch = "x86", target_arch = "x86_64"),
-                target_feature = "vaes",
-                not(miri)
-            ))] {
-                use core::arch::x86_64::*;
-                unsafe {
-                    [
-                        transmute!(_mm256_extracti128_si256(x, 0)),
-                        transmute!(_mm256_extracti128_si256(x, 1)),
-                    ]
-                }
-            }
-            else {
-                [x.0, x.1]
-            }
-        }
-    }
 }
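
A rough equivalence check for review purposes, not part of the patch above. It is a sketch that assumes it would live in src/operations.rs under the same cfg gate as `mod vaes`, so that `aesdec`, `aesenc`, `shuffle_and_add`, `U256`, and the new 256-bit helpers are all in scope; the module name, test name, and lane values are made up:

    #[cfg(test)]
    mod vaes_equivalence {
        use super::*;

        // Each 256-bit helper should behave like the existing 128-bit helper applied
        // to each 128-bit lane independently, whichever branch of the cfg_if was
        // compiled in (VAES intrinsics or the scalar fallback).
        #[test]
        fn vec256_helpers_match_scalar_lanes() {
            // Arbitrary lane values.
            let a: U256 = [0x0123_4567_89ab_cdef_u128, !0u128];
            let b: U256 = [0x0000_dead_beef_0000_u128, 0x1234_u128];
            assert_eq!(vaes::aesdec_vec256(a, b), [aesdec(a[0], b[0]), aesdec(a[1], b[1])]);
            assert_eq!(vaes::aesenc_vec256(a, b), [aesenc(a[0], b[0]), aesenc(a[1], b[1])]);
            assert_eq!(
                vaes::shuffle_and_add_vec256(a, b),
                [shuffle_and_add(a[0], b[0]), shuffle_and_add(a[1], b[1])]
            );
        }
    }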