
Commit

First pass implementation of VAES
Signed-off-by: Tom Kaitchuck <[email protected]>
tkaitchuck committed Aug 3, 2024
1 parent a38d3ee commit 58cc385
Showing 3 changed files with 51 additions and 102 deletions.
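
For context: VAES extends the AES-NI round instructions to 256-bit registers, so a single instruction performs an independent AES round on each 128-bit half of a __m256i. A minimal portable sketch of that lane-wise shape, with a placeholder mixer standing in for the AES round (names invented here for illustration, not ahash's API):

    // Illustrative only: a 256-bit value modeled as two 128-bit lanes.
    type U256 = [u128; 2];

    // Placeholder for one AES round (aesenc/aesdec); any per-lane mixing
    // function shows the data flow.
    fn mix_lane(a: u128, b: u128) -> u128 {
        a.rotate_left(29) ^ b
    }

    // With the `vaes` target feature this whole function corresponds to a
    // single 256-bit instruction; without it, it is simply the two scalar
    // operations below.
    fn mix_vec256(value: U256, xor: U256) -> U256 {
        [mix_lane(value[0], xor[0]), mix_lane(value[1], xor[1])]
    }

    fn main() {
        println!("{:?}", mix_vec256([1, 2], [3, 4]));
    }

The diffs below follow exactly this shape: one signature over plain [u128; 2] values, an intrinsic path when the vaes feature is available at compile time, and a per-lane fallback otherwise.
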
35 changes: 14 additions & 21 deletions src/aes_hash.rs
@@ -161,31 +161,24 @@ impl Hasher for AHasher {
         } else {
             if data.len() > 32 {
                 if data.len() > 64 {
-                    let tail = data.read_last_u128x4();
-                    let mut current: [u128; 4] = [self.key; 4];
-                    current[0] = aesenc(current[0], tail[0]);
-                    current[1] = aesdec(current[1], tail[1]);
-                    current[2] = aesenc(current[2], tail[2]);
-                    current[3] = aesdec(current[3], tail[3]);
-                    let mut sum: [u128; 2] = [self.key, !self.key];
-                    sum[0] = add_by_64s(sum[0].convert(), tail[0].convert()).convert();
-                    sum[1] = add_by_64s(sum[1].convert(), tail[1].convert()).convert();
-                    sum[0] = shuffle_and_add(sum[0], tail[2]);
-                    sum[1] = shuffle_and_add(sum[1], tail[3]);
+                    let tail: [U256; 2] = data.read_last_u128x4().convert();
+                    let mut current: [U256; 2] = [self.key; 4].convert();
+                    current[0] = vaes::aesenc_vec256(current[0], tail[0]);
+                    current[1] = vaes::aesenc_vec256(current[1], tail[1]);
+                    let mut sum: U256 = [self.key, !self.key];
+                    sum = vaes::add_by_64s_vec256(sum,tail[0]);
+                    sum = vaes::shuffle_and_add_vec256(sum, tail[1]);
                     while data.len() > 64 {
                         let (blocks, rest) = data.read_u128x4();
-                        current[0] = aesdec(current[0], blocks[0]);
-                        current[1] = aesdec(current[1], blocks[1]);
-                        current[2] = aesdec(current[2], blocks[2]);
-                        current[3] = aesdec(current[3], blocks[3]);
-                        sum[0] = shuffle_and_add(sum[0], blocks[0]);
-                        sum[1] = shuffle_and_add(sum[1], blocks[1]);
-                        sum[0] = shuffle_and_add(sum[0], blocks[2]);
-                        sum[1] = shuffle_and_add(sum[1], blocks[3]);
+                        let blocks: [U256;2] = blocks.convert();
+                        current[0] = vaes::aesdec_vec256(current[0], blocks[0]);
+                        current[1] = vaes::aesdec_vec256(current[1], blocks[1]);
+                        sum = vaes::shuffle_and_add_vec256(sum, blocks[0]);
+                        sum = vaes::shuffle_and_add_vec256(sum, blocks[1]);
                         data = rest;
                     }
-                    self.hash_in_2(current[0], current[1]);
-                    self.hash_in_2(current[2], current[3]);
+                    self.hash_in_2(current[0][0], current[0][1]);
+                    self.hash_in_2(current[1][0], current[1][1]);
                     self.hash_in_2(sum[0], sum[1]);
                 } else {
                     //len 33-64
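
The new path above keeps two 256-bit `current` lanes plus one 256-bit `sum` instead of four and two 128-bit accumulators. A rough, self-contained sketch of that traversal for inputs longer than 64 bytes, with a placeholder fold standing in for the AES and shuffle/add steps (helper names invented here, not ahash's API):

    type U256 = [u128; 2];

    // Placeholder for the aesdec / shuffle_and_add folds used by the real code.
    fn fold(acc: U256, block: U256) -> U256 {
        [acc[0].wrapping_add(block[0]).rotate_left(23),
         acc[1].wrapping_add(block[1]).rotate_left(41)]
    }

    // Read 64 bytes as two 256-bit lanes (four little-endian u128 words).
    fn read_u256x2(bytes: &[u8]) -> [U256; 2] {
        let mut w = [0u128; 4];
        for (i, chunk) in bytes[..64].chunks_exact(16).enumerate() {
            w[i] = u128::from_le_bytes(chunk.try_into().unwrap());
        }
        [[w[0], w[1]], [w[2], w[3]]]
    }

    // Mirrors the shape of the >64-byte branch: mix in the (possibly
    // overlapping) final 64 bytes first, then consume 64-byte blocks from the
    // front, keeping two 256-bit `current` lanes and one 256-bit `sum`.
    fn hash_large(mut data: &[u8], key: u128) -> (U256, U256, U256) {
        assert!(data.len() > 64);
        let tail = read_u256x2(&data[data.len() - 64..]);
        let mut current = [fold([key; 2], tail[0]), fold([key; 2], tail[1])];
        let mut sum = fold(fold([key, !key], tail[0]), tail[1]);
        while data.len() > 64 {
            let blocks = read_u256x2(data);
            current[0] = fold(current[0], blocks[0]);
            current[1] = fold(current[1], blocks[1]);
            sum = fold(fold(sum, blocks[0]), blocks[1]);
            data = &data[64..];
        }
        // The real code then folds current[0], current[1] and sum into the
        // 128-bit state via hash_in_2.
        (current[0], current[1], sum)
    }
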
7 changes: 7 additions & 0 deletions src/convert.rs
@@ -2,6 +2,9 @@ pub(crate) trait Convert<To> {
     fn convert(self) -> To;
 }
 
+pub type U256 = [u128; 2];
+pub type U512 = [u128; 4];
+
 macro_rules! convert {
     ($a:ty, $b:ty) => {
         impl Convert<$b> for $a {
@@ -19,6 +22,10 @@ macro_rules! convert {
     };
 }
 
+convert!([U256; 2], [U512; 1]);
+convert!([u128; 4], [U512; 1]);
+convert!([u128; 4], [U256; 2]);
+convert!([u128; 2], [U256; 1]);
 convert!([u128; 4], [u64; 8]);
 convert!([u128; 4], [u32; 16]);
 convert!([u128; 4], [u16; 32]);
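
The new aliases and convert! instances above let the hashing code regroup a [u128; 4] read as two U256 lanes (and similarly for U512). A hypothetical stand-alone equivalent of that regrouping, not the crate's macro expansion:

    type U256 = [u128; 2];

    // Same 64 bytes, viewed as two 256-bit lanes instead of four 128-bit words.
    fn as_u256x2(a: [u128; 4]) -> [U256; 2] {
        [[a[0], a[1]], [a[2], a[3]]]
    }

    fn main() {
        assert_eq!(as_u256x2([1, 2, 3, 4]), [[1, 2], [3, 4]]);
    }
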
111 changes: 30 additions & 81 deletions src/operations.rs
@@ -187,30 +187,33 @@ pub(crate) fn add_in_length(enc: &mut u128, len: u64) {
     all(target_arch = "aarch64", target_feature = "aes", not(miri)),
     all(feature = "nightly-arm-aes", target_arch = "arm", target_feature = "aes", not(miri)),
 ))]
-mod vaes {
+pub(crate) mod vaes {
     use super::*;
-    cfg_if::cfg_if! {
-        if #[cfg(all(
-            any(target_arch = "x86", target_arch = "x86_64"),
-            target_feature = "vaes", not(miri)
-        ))] {
-            pub type Vector256 = core::arch::x86_64::__m256i;
-
-            // impl FromBytes for Vector256 {
-            //     fn from_bytes(bytes: &[u8]) -> Self {
-            //         unsafe {
-            //             core::arch::x86_64::_mm256_loadu_si256(bytes.as_ptr().cast::<core::arch::x86_64::__m256i>())
-            //         }
-            //     }
-            // }
-
-        } else {
-            pub type Vector256 = [u128;2];
+
+    #[inline(always)]
+    pub(crate) fn aesdec_vec256(value: U256, xor: U256) -> U256 {
+        cfg_if::cfg_if! {
+            if #[cfg(all(
+                any(target_arch = "x86", target_arch = "x86_64"),
+                target_feature = "vaes",
+                not(miri)
+            ))] {
+                use core::arch::x86_64::*;
+                unsafe {
+                    transmute!(_mm256_aesdec_epi128(transmute!(value), transmute!(xor)))
+                }
+            }
+            else {
+                [
+                    aesdec(value[0], xor[0]),
+                    aesdec(value[1], xor[1]),
+                ]
+            }
         }
     }
 
     #[inline(always)]
-    pub(crate) fn aesdec_vec256(value: Vector256, xor: Vector256) -> Vector256 {
+    pub(crate) fn aesenc_vec256(value: U256, xor: U256) -> U256 {
         cfg_if::cfg_if! {
             if #[cfg(all(
                 any(target_arch = "x86", target_arch = "x86_64"),
@@ -219,28 +222,28 @@ mod vaes {
             ))] {
                 use core::arch::x86_64::*;
                 unsafe {
-                    _mm256_aesdec_epi128(value, xor)
+                    transmute!(_mm256_aesenc_epi128(transmute!(value), transmute!(xor)))
                 }
             }
             else {
                 [
-                    aesdec(value[0], xor[0]),
-                    aesdec(value[1], xor[1]),
+                    aesenc(value[0], xor[0]),
+                    aesenc(value[1], xor[1]),
                 ]
             }
         }
     }
 
     #[inline(always)]
-    pub(crate) fn add_by_64s_vec256(a: Vector256, b: Vector256) -> Vector256 {
+    pub(crate) fn add_by_64s_vec256(a: U256, b: U256) -> U256 {
         cfg_if::cfg_if! {
             if #[cfg(all(
                 any(target_arch = "x86", target_arch = "x86_64"),
                 target_feature = "vaes",
                 not(miri)
             ))] {
                 use core::arch::x86_64::*;
-                unsafe { _mm256_add_epi64(a, b) }
+                unsafe { transmute!(_mm256_add_epi64(transmute!(a), transmute!(b))) }
             }
             else {
                 [
@@ -252,7 +255,7 @@ mod vaes {
     }
 
     #[inline(always)]
-    pub(crate) fn shuffle_vec256(value: Vector256) -> Vector256 {
+    pub(crate) fn shuffle_vec256(value: U256) -> U256 {
         cfg_if::cfg_if! {
             if #[cfg(all(
                 any(target_arch = "x86", target_arch = "x86_64"),
                 target_feature = "vaes",
                 not(miri)
             ))] {
                 unsafe {
                     use core::arch::x86_64::*;
-                    let mask = convert_u128_to_vec256(SHUFFLE_MASK, SHUFFLE_MASK);
-                    _mm256_shuffle_epi8(value, mask)
+                    let mask = transmute!([SHUFFLE_MASK, SHUFFLE_MASK]);
+                    transmute!(_mm256_shuffle_epi8(transmute!(value), mask))
                 }
             }
             else {
@@ -275,63 +278,9 @@ mod vaes {
         }
     }
 
-    pub(crate) fn shuffle_and_add_vec256(base: Vector256, to_add: Vector256) -> Vector256 {
+    pub(crate) fn shuffle_and_add_vec256(base: U256, to_add: U256) -> U256 {
         add_by_64s_vec256(shuffle_vec256(base), to_add)
     }
 
-    // We specialize this routine because sometimes the compiler is not able to
-    // optimize it properly.
-    pub(crate) fn read2_vec256(data: &[u8]) -> ([Vector256; 2], &[u8]) {
-        cfg_if::cfg_if! {
-            if #[cfg(all(
-                any(target_arch = "x86", target_arch = "x86_64"),
-                target_feature = "vaes",
-                not(miri)
-            ))] {
-                use core::arch::x86_64::*;
-                let (arr, rem) = data.split_at(64);
-                let arr = unsafe {
-                    [ _mm256_loadu_si256(arr.as_ptr().cast::<__m256i>()),
-                      _mm256_loadu_si256(arr.as_ptr().add(32).cast::<__m256i>()),
-                    ]
-                };
-                (arr, rem)
-            }
-            else {
-                let (arr, slice) = data.read_u128x4();
-                (transmute!(arr), slice)
-            }
-        }
-    }
-
-    // We specialize this routine because sometimes the compiler is not able to
-    // optimize it properly.
-    pub(crate) fn convert_u128_to_vec256(low: u128, high: u128) -> Vector256 {
-        transmute!([low, high])
-    }
-
-    // We specialize this routine because sometimes the compiler is not able to
-    // optimize it properly.
-    pub(crate) fn convert_vec256_to_u128(x: Vector256) -> [u128; 2] {
-        cfg_if::cfg_if! {
-            if #[cfg(all(
-                any(target_arch = "x86", target_arch = "x86_64"),
-                target_feature = "vaes",
-                not(miri)
-            ))] {
-                use core::arch::x86_64::*;
-                unsafe {
-                    [
-                        transmute!(_mm256_extracti128_si256(x, 0)),
-                        transmute!(_mm256_extracti128_si256(x, 1)),
-                    ]
-                }
-            }
-            else {
-                [x.0, x.1]
-            }
-        }
-    }
 }


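All of the helpers above share one idiom: callers pass plain U256 = [u128; 2] values, and the __m256i type appears only inside the intrinsic branch behind transmute!. A self-contained sketch of the same compile-time dispatch, using the stable AVX2 add intrinsic rather than the VAES ones (the feature gate and names below are chosen for illustration, not taken from ahash):

    type U256 = [u128; 2];

    // Wide path: compiled only when the needed x86-64 features are enabled at
    // build time. One 256-bit instruction performs all four 64-bit additions.
    #[cfg(all(target_arch = "x86_64", target_feature = "avx2", not(miri)))]
    #[inline(always)]
    fn add_by_64s_vec256_sketch(a: U256, b: U256) -> U256 {
        use core::arch::x86_64::*;
        use core::mem::transmute;
        unsafe { transmute(_mm256_add_epi64(transmute(a), transmute(b))) }
    }

    // Portable fallback: the same four wrapping u64 additions, lane by lane.
    #[cfg(not(all(target_arch = "x86_64", target_feature = "avx2", not(miri))))]
    #[inline(always)]
    fn add_by_64s_vec256_sketch(a: U256, b: U256) -> U256 {
        fn add64s(x: u128, y: u128) -> u128 {
            let lo = (x as u64).wrapping_add(y as u64);
            let hi = ((x >> 64) as u64).wrapping_add((y >> 64) as u64);
            (lo as u128) | ((hi as u128) << 64)
        }
        [add64s(a[0], b[0]), add64s(a[1], b[1])]
    }

    fn main() {
        println!("{:x?}", add_by_64s_vec256_sketch([u128::MAX, 7], [1, 8]));
    }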
