Skip to content
This repository has been archived by the owner on Jan 22, 2025. It is now read-only.

Feat/no threads feat for zk sdk #35514

Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions zk-token-sdk/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -39,3 +39,7 @@ zeroize = { workspace = true, features = ["zeroize_derive"] }

[lib]
crate-type = ["cdylib", "rlib"]

[features]
default = ["enable-threaded"]
enable-threaded = []
94 changes: 64 additions & 30 deletions zk-token-sdk/src/encryption/discrete_log.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@

#![cfg(not(target_os = "solana"))]

#[cfg(feature = "enable-threaded")]
use std::thread;
use {
crate::RISTRETTO_POINT_LEN,
curve25519_dalek::{
Expand All @@ -26,7 +28,7 @@ use {
},
itertools::Itertools,
serde::{Deserialize, Serialize},
std::{collections::HashMap, thread},
std::collections::HashMap,
thiserror::Error,
};

Expand Down Expand Up @@ -137,41 +139,53 @@ impl DiscreteLog {

Ok(())
}

/// Solves the discrete log problem under the assumption that the solution
/// is a positive 32-bit number.
pub fn decode_u32(self) -> Option<u64> {
let mut starting_point = self.target;
let handles = (0..self.num_threads)
.map(|i| {
let ristretto_iterator = RistrettoIterator::new(
(starting_point, i as u64),
(-(&self.step_point), self.num_threads as u64),
);

let handle = thread::spawn(move || {
Self::decode_range(
ristretto_iterator,
self.range_bound,
self.compression_batch_size,
)
});

starting_point -= G;
handle
})
.collect::<Vec<_>>();

let mut solution = None;
for handle in handles {
let discrete_log = handle.join().unwrap();
if discrete_log.is_some() {
solution = discrete_log;
#[cfg(not(feature = "enable-threaded"))]
{
let ristretto_iterator =
RistrettoIterator::new((self.target, 0_u64), (-(&self.step_point), 1_u64));
Self::decode_range(
ristretto_iterator,
self.range_bound,
self.compression_batch_size,
)
}

#[cfg(feature = "enable-threaded")]
{
let mut starting_point = self.target;
let handles = (0..self.num_threads)
.map(|i| {
let ristretto_iterator = RistrettoIterator::new(
(starting_point, i as u64),
(-(&self.step_point), self.num_threads as u64),
);

let handle = thread::spawn(move || {
Self::decode_range(
ristretto_iterator,
self.range_bound,
self.compression_batch_size,
)
});

starting_point -= G;
handle
})
.collect::<Vec<_>>();

let mut solution = None;
for handle in handles {
let discrete_log = handle.join().unwrap();
if discrete_log.is_some() {
solution = discrete_log;
}
}
solution
}
solution
}

fn decode_range(
ristretto_iterator: RistrettoIterator,
range_bound: usize,
Expand Down Expand Up @@ -258,6 +272,7 @@ mod tests {
}

#[test]
#[cfg(feature = "enable-threaded")]
fn test_decode_correctness() {
// general case
let amount: u64 = 4294967295;
Expand All @@ -271,10 +286,29 @@ mod tests {

assert_eq!(amount, decoded.unwrap());

println!("no threads discrete log computation secs: {computation_secs:?} sec");
}

#[test]
#[cfg(not(feature = "enable-threaded"))]
fn test_decode_correctness_no_threads_feat() {
// general case
let amount: u64 = 4294967295;

let instance = DiscreteLog::new(G, Scalar::from(amount) * G);

// Very informal measurements for now
let start_computation = Instant::now();
let decoded = instance.decode_u32();
let computation_secs = start_computation.elapsed().as_secs_f64();

assert_eq!(amount, decoded.unwrap());

println!("single thread discrete log computation secs: {computation_secs:?} sec");
}

#[test]
#[cfg(feature = "enable-threaded")]
fn test_decode_correctness_threaded() {
// general case
let amount: u64 = 55;
Expand Down
1 change: 1 addition & 0 deletions zk-token-sdk/src/encryption/elgamal.rs
Original file line number Diff line number Diff line change
Expand Up @@ -792,6 +792,7 @@ mod tests {
}

#[test]
#[cfg(feature = "enable-threaded")]
fn test_encrypt_decrypt_correctness_multithreaded() {
let ElGamalKeypair { public, secret } = ElGamalKeypair::new_rand();
let amount: u32 = 57;
Expand Down
Loading