From 7239149f4dccadb95bf530c6c31867e8510b7ea7 Mon Sep 17 00:00:00 2001 From: Robert Debug Date: Tue, 25 Aug 2020 21:29:37 +0200 Subject: [PATCH] make scalar configurable --- rust/examples/mobile-client.rs | 18 ++++++++----- .../xaynet-client/src/mobile_client/client.rs | 12 +++------ .../src/mobile_client/participant/awaiting.rs | 18 +++++++------ .../src/mobile_client/participant/mod.rs | 14 ++++++++--- .../src/mobile_client/participant/sum2.rs | 25 +++++++++++-------- .../src/mobile_client/participant/update.rs | 9 ++++--- 6 files changed, 56 insertions(+), 40 deletions(-) diff --git a/rust/examples/mobile-client.rs b/rust/examples/mobile-client.rs index a7eb538ea..a424e4251 100644 --- a/rust/examples/mobile-client.rs +++ b/rust/examples/mobile-client.rs @@ -4,7 +4,10 @@ extern crate tracing; use std::io::{stdin, stdout, Read, Write}; use structopt::StructOpt; use tracing_subscriber::*; -use xaynet_client::mobile_client::{participant::ParticipantSettings, MobileClient}; +use xaynet_client::mobile_client::{ + participant::{AggregationConfig, ParticipantSettings}, + MobileClient, +}; use xaynet_core::mask::{ BoundType, DataType, @@ -41,11 +44,14 @@ fn get_participant_settings() -> ParticipantSettings { let secret_key = MobileClient::create_participant_secret_key(); ParticipantSettings { secret_key, - mask_config: MaskConfig { - group_type: GroupType::Prime, - data_type: DataType::F32, - bound_type: BoundType::B0, - model_type: ModelType::M3, + aggregation_config: AggregationConfig { + mask: MaskConfig { + group_type: GroupType::Prime, + data_type: DataType::F32, + bound_type: BoundType::B0, + model_type: ModelType::M3, + }, + scalar: 1_f64, }, } } diff --git a/rust/xaynet-client/src/mobile_client/client.rs b/rust/xaynet-client/src/mobile_client/client.rs index 63644aac8..89545245a 100644 --- a/rust/xaynet-client/src/mobile_client/client.rs +++ b/rust/xaynet-client/src/mobile_client/client.rs @@ -170,21 +170,15 @@ impl ClientState { .await .ok_or(ClientError::TooEarly("local model"))?; - debug!("setting model scalar"); - let scalar = 1_f64; // TODO parametrise this! - debug!("polling for sum dict"); let sums = api .get_sums() .await? .ok_or(ClientError::TooEarly("sum dict"))?; - let upd_msg = self.participant.compose_update_message( - self.round_params.pk, - &sums, - scalar, - local_model, - ); + let upd_msg = + self.participant + .compose_update_message(self.round_params.pk, &sums, local_model); let sealed_msg = self .participant .seal_message(&self.round_params.pk, &upd_msg); diff --git a/rust/xaynet-client/src/mobile_client/participant/awaiting.rs b/rust/xaynet-client/src/mobile_client/participant/awaiting.rs index 95172df8c..58899ccfc 100644 --- a/rust/xaynet-client/src/mobile_client/participant/awaiting.rs +++ b/rust/xaynet-client/src/mobile_client/participant/awaiting.rs @@ -48,6 +48,7 @@ impl Participant { #[cfg(test)] mod tests { use super::*; + use crate::mobile_client::participant::AggregationConfig; use sodiumoxide::randombytes::randombytes; use xaynet_core::{ crypto::{ByteObject, SigningKeyPair}, @@ -59,16 +60,19 @@ mod tests { fn participant_state() -> ParticipantState { sodiumoxide::init().unwrap(); - let mask_config = MaskConfig { - group_type: GroupType::Prime, - data_type: DataType::F32, - bound_type: BoundType::B0, - model_type: ModelType::M3, - }; + let aggregation_config = AggregationConfig { + mask: MaskConfig { + group_type: GroupType::Prime, + data_type: DataType::F32, + bound_type: BoundType::B0, + model_type: ModelType::M3, + }, + scalar: 1_f64, + }; ParticipantState { keys: SigningKeyPair::generate(), - mask_config, + aggregation_config, } } diff --git a/rust/xaynet-client/src/mobile_client/participant/mod.rs b/rust/xaynet-client/src/mobile_client/participant/mod.rs index 695751b3c..f175063e1 100644 --- a/rust/xaynet-client/src/mobile_client/participant/mod.rs +++ b/rust/xaynet-client/src/mobile_client/participant/mod.rs @@ -19,25 +19,31 @@ pub mod update; pub use self::{awaiting::Awaiting, sum::Sum, sum2::Sum2, update::Update}; +#[derive(Serialize, Deserialize)] +pub struct AggregationConfig { + pub mask: MaskConfig, + pub scalar: f64, +} + #[derive(Serialize, Deserialize)] pub struct ParticipantState { // credentials pub keys: SigningKeyPair, // Mask config - pub mask_config: MaskConfig, + pub aggregation_config: AggregationConfig, } #[derive(Serialize, Deserialize)] pub struct ParticipantSettings { pub secret_key: ParticipantSecretKey, - pub mask_config: MaskConfig, + pub aggregation_config: AggregationConfig, } impl From for ParticipantState { fn from( ParticipantSettings { secret_key, - mask_config, + aggregation_config, }: ParticipantSettings, ) -> ParticipantState { ParticipantState { @@ -45,7 +51,7 @@ impl From for ParticipantState { public: secret_key.public_key(), secret: secret_key, }, - mask_config, + aggregation_config, } } } diff --git a/rust/xaynet-client/src/mobile_client/participant/sum2.rs b/rust/xaynet-client/src/mobile_client/participant/sum2.rs index f75c57be6..641f9a142 100644 --- a/rust/xaynet-client/src/mobile_client/participant/sum2.rs +++ b/rust/xaynet-client/src/mobile_client/participant/sum2.rs @@ -84,10 +84,11 @@ impl Participant { return Err(PetError::InvalidMask); } - let mut model_mask_agg = Aggregation::new(self.state.mask_config, mask_len); - let mut scalar_mask_agg = Aggregation::new(self.state.mask_config, 1); + let mut model_mask_agg = Aggregation::new(self.state.aggregation_config.mask, mask_len); + let mut scalar_mask_agg = Aggregation::new(self.state.aggregation_config.mask, 1); for seed in mask_seeds.into_iter() { - let (model_mask, scalar_mask) = seed.derive_mask(mask_len, self.state.mask_config); + let (model_mask, scalar_mask) = + seed.derive_mask(mask_len, self.state.aggregation_config.mask); model_mask_agg .validate_aggregation(&model_mask) @@ -106,6 +107,7 @@ impl Participant { #[cfg(test)] mod tests { use super::*; + use crate::mobile_client::participant::AggregationConfig; use sodiumoxide::randombytes::{randombytes, randombytes_uniform}; use std::{collections::HashSet, iter}; use xaynet_core::{ @@ -117,16 +119,19 @@ mod tests { fn participant_state() -> ParticipantState { sodiumoxide::init().unwrap(); - let mask_config = MaskConfig { - group_type: GroupType::Prime, - data_type: DataType::F32, - bound_type: BoundType::B0, - model_type: ModelType::M3, - }; + let aggregation_config = AggregationConfig { + mask: MaskConfig { + group_type: GroupType::Prime, + data_type: DataType::F32, + bound_type: BoundType::B0, + model_type: ModelType::M3, + }, + scalar: 1_f64, + }; ParticipantState { keys: SigningKeyPair::generate(), - mask_config, + aggregation_config, } } diff --git a/rust/xaynet-client/src/mobile_client/participant/update.rs b/rust/xaynet-client/src/mobile_client/participant/update.rs index 292ce3150..f85b4a054 100644 --- a/rust/xaynet-client/src/mobile_client/participant/update.rs +++ b/rust/xaynet-client/src/mobile_client/participant/update.rs @@ -34,10 +34,10 @@ impl Participant { &self, coordinator_pk: CoordinatorPublicKey, sum_dict: &SumDict, - scalar: f64, + local_model: Model, ) -> Message { - let (mask_seed, masked_model, masked_scalar) = self.mask_model(scalar, local_model); + let (mask_seed, masked_model, masked_scalar) = self.mask_model(local_model); let local_seed_dict = Self::create_local_seed_dict(sum_dict, &mask_seed); Message { @@ -56,8 +56,9 @@ impl Participant { } /// Generate a mask seed and mask a local model. - fn mask_model(&self, scalar: f64, local_model: Model) -> (MaskSeed, MaskObject, MaskObject) { - Masker::new(self.state.mask_config).mask(scalar, local_model) + fn mask_model(&self, local_model: Model) -> (MaskSeed, MaskObject, MaskObject) { + Masker::new(self.state.aggregation_config.mask) + .mask(self.state.aggregation_config.scalar, local_model) } // Create a local seed dictionary from a sum dictionary.