Skip to content

Commit

Permalink
state for subscribed topics
Browse files Browse the repository at this point in the history
  • Loading branch information
rob-maron committed Feb 7, 2024
1 parent 5620289 commit 7cd3ac0
Show file tree
Hide file tree
Showing 13 changed files with 321 additions and 51 deletions.
54 changes: 42 additions & 12 deletions broker/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

mod state;

use std::{marker::PhantomData, sync::Arc, time::Duration};
use std::{collections::HashMap, marker::PhantomData, sync::Arc, time::Duration};

use jf_primitives::signatures::SignatureScheme as JfSignatureScheme;
// TODO: figure out if we should use Tokio's here
Expand All @@ -17,11 +17,13 @@ use proto::{
},
crypto::Serializable,
error::{Error, Result},
message::Topic,
parse_socket_address,
redis::{self, BrokerIdentifier},
verify_broker,
};
use tokio::{select, spawn, time::sleep};
use state::{ConnectionLookup, ConnectionWithQueue};
use tokio::{select, spawn, sync::RwLock, time::{sleep, Instant}};
use tracing::{error, warn};

/// The broker's configuration. We need this when we create a new one.
Expand Down Expand Up @@ -82,6 +84,10 @@ struct Inner<
/// authentication phase.
pub signing_key: BrokerSignatureScheme::SigningKey,

pub broker_connections: RwLock<ConnectionLookup<UserSignatureScheme, BrokerProtocolType>>,

pub user_connections: RwLock<ConnectionLookup<UserSignatureScheme, UserProtocolType>>,

// connected_keys: LoggedSet<UserSignatureScheme::VerificationKey>,
/// The `PhantomData` that we need to be generic over protocol types.
pd: PhantomData<(UserProtocolType, BrokerProtocolType, UserSignatureScheme)>,
Expand Down Expand Up @@ -200,6 +206,8 @@ where
identifier,
verification_key,
signing_key,
broker_connections: RwLock::from(ConnectionLookup::default()),
user_connections: RwLock::from(ConnectionLookup::default()),
pd: PhantomData,
}),
user_listener,
Expand Down Expand Up @@ -229,6 +237,17 @@ where
verify_broker!(connection, inner);
authenticate_with_broker!(connection, inner)
};

// Create a new queued connection
let connection = Arc::from(
ConnectionWithQueue::<BrokerProtocolType>::from_connection_and_params(
connection,
Duration::from_millis(50),
5000,
),
);

// Add the connection to our map
}

/// This function handles a user (public) connection. We take the following steps:
Expand All @@ -241,7 +260,7 @@ where
connection: UserProtocolType::Connection,
) {
// Verify (authenticate) the connection
let Ok(verification_key) =
let Ok((verification_key, topics)) =
BrokerAuth::<UserSignatureScheme, UserProtocolType>::verify_user(
&connection,
&inner.identifier,
Expand All @@ -252,18 +271,29 @@ where
return;
};

println!("meow");
// Create a new queued connection
let connection = Arc::from(
ConnectionWithQueue::<UserProtocolType>::from_connection_and_params(
connection,
Duration::from_millis(50),
5000,
),
);

// // Create a new queued connection
// let connection_with_queue = ConnectionWithQueue{
// connection: connection,
// last_sent: SystemTime::now(),
// buffer: Arc::default(),
println!("user subbed to {:?}", topics);

// }
// Add the connection to our maps
inner
.user_connections
.write()
.await
.subscribe_connection_to_broadcast(connection.clone(), topics);

// // Add to our direct map
// inner.user_to_connection.write().await.insert(verification_key, Either::Left());
inner
.user_connections
.write()
.await
.subscribe_connection_to_direct(connection.clone(), verification_key)
}

/// The main loop for a broker.
Expand Down
192 changes: 192 additions & 0 deletions broker/src/state.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,192 @@
use std::{
collections::{HashMap, HashSet},
hash::Hash,
sync::{
atomic::{AtomicU64, Ordering},
Arc,
},
time::Duration,
};

use jf_primitives::signatures::SignatureScheme as JfSignatureScheme;
// TODO: maybe use Tokio's RwLock
use parking_lot::RwLock;
use proto::{
connection::protocols::{Connection, Protocol},
crypto::Serializable,
message::Topic,
};
use tokio::{spawn, sync::Mutex, time::Instant};

pub struct ConnectionLookup<
SignatureScheme: JfSignatureScheme<PublicParameter = (), MessageUnit = u8>,
ProtocolType: Protocol,
> where
SignatureScheme::VerificationKey: Serializable,
{
direct_message_lookup:
HashMap<SignatureScheme::VerificationKey, Arc<ConnectionWithQueue<ProtocolType>>>,
broadcast_message_lookup: HashMap<Topic, HashSet<Arc<ConnectionWithQueue<ProtocolType>>>>,
inverse_broadcast_message_lookup:
HashMap<Arc<ConnectionWithQueue<ProtocolType>>, HashSet<Topic>>,
}

impl<
SignatureScheme: JfSignatureScheme<PublicParameter = (), MessageUnit = u8>,
ProtocolType: Protocol,
> Default for ConnectionLookup<SignatureScheme, ProtocolType>
where
SignatureScheme::Signature: Serializable,
SignatureScheme::VerificationKey: Serializable,
SignatureScheme::SigningKey: Serializable,
{
fn default() -> Self {
Self {
direct_message_lookup: HashMap::default(),
broadcast_message_lookup: HashMap::default(),
inverse_broadcast_message_lookup: HashMap::default(),
}
}
}

impl<
SignatureScheme: JfSignatureScheme<PublicParameter = (), MessageUnit = u8>,
ProtocolType: Protocol,
> ConnectionLookup<SignatureScheme, ProtocolType>
where
SignatureScheme::VerificationKey: Serializable,
{
pub fn subscribe_connection_to_broadcast(
&mut self,
connection: Arc<ConnectionWithQueue<ProtocolType>>,
topics: Vec<Topic>,
) {
//topic -> [connection]
for topic in topics.clone() {
self.broadcast_message_lookup
.entry(topic)
.or_default()
.insert(connection.clone());
}
//connection -> [topic]
self.inverse_broadcast_message_lookup
.entry(connection)
.or_default()
.extend(topics);
}

pub fn unsubscribe_connection_from_broadcast(
&mut self,
connection: Arc<ConnectionWithQueue<ProtocolType>>,
topics: Vec<Topic>,
) {
//topic -> [connection]
for topic in topics.clone() {
// remove connection from topic, and remove topic if empty
if let Some(connections) = self.broadcast_message_lookup.get_mut(&topic) {
connections.remove(&connection);
}
}

//key -> [topic]
if let Some(connection_topics) = self.inverse_broadcast_message_lookup.get_mut(&connection)
{
for topic in topics {
connection_topics.remove(&topic);
}
}
}

pub fn subscribe_connection_to_direct(
&mut self,
connection: Arc<ConnectionWithQueue<ProtocolType>>,
key: SignatureScheme::VerificationKey,
) {
self.direct_message_lookup.insert(key, connection);
}

pub fn unsubscribe_connection_from_direct(&mut self, key: SignatureScheme::VerificationKey) {
self.direct_message_lookup.remove(&key);
}
}

pub struct ConnectionWithQueue<ProtocolType: Protocol> {
queue: Mutex<Vec<Arc<Vec<u8>>>>,
connection: ProtocolType::Connection,

current_size: AtomicU64,
last_sent: RwLock<Instant>,

min_duration: Duration,
min_size: u64,
}

impl<ProtocolType: Protocol> PartialEq for ConnectionWithQueue<ProtocolType> {
fn eq(&self, other: &Self) -> bool {
self.connection == other.connection
}
}

impl<ProtocolType: Protocol> Eq for ConnectionWithQueue<ProtocolType> {
fn assert_receiver_is_total_eq(&self) {}
}

impl<ProtocolType: Protocol> Hash for ConnectionWithQueue<ProtocolType> {
fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
self.connection.hash(state);
}

/// This just calls `hash` on each item in the slice.
fn hash_slice<H: std::hash::Hasher>(data: &[Self], state: &mut H)
where
Self: Sized,
{
data.iter().for_each(|item| item.hash(state));
}
}

impl<ProtocolType: Protocol> ConnectionWithQueue<ProtocolType> {
pub fn from_connection_and_params(
connection: ProtocolType::Connection,
min_duration: Duration,
min_size: u64,
) -> Self {
Self {
queue: Mutex::default(),
connection,
current_size: AtomicU64::default(),
last_sent: RwLock::from(Instant::now()),
min_duration,
min_size,
}
}

pub async fn add_or_queue_message(&self, message: Arc<Vec<u8>>) {
// Push the reference to the message
let message_length = message.len() as u64;
let mut queue_guard = self.queue.lock().await;
queue_guard.push(message);

// Update our size
let before_send_size = self
.current_size
.fetch_add(message_length, Ordering::Relaxed);

// Bounds check to see if we should send
if (before_send_size + message_length) >= self.min_size
|| self.last_sent.read().elapsed() >= self.min_duration
{
// Move messages out
// TODO: VEC WITH CAPACITY HERE
let messages = std::mem::replace(&mut *queue_guard, Vec::new());

// Spawn a task to flush our queue
// TODO: see if it's faster to not have this here
let connection = self.connection.clone();
spawn(async move {
// Send the entire batch of messages
let _ = connection.send_messages_raw(messages).await;
});
}
}
}
14 changes: 0 additions & 14 deletions broker/src/state/broker.rs

This file was deleted.

2 changes: 0 additions & 2 deletions broker/src/state/mod.rs

This file was deleted.

Empty file removed broker/src/state/user.rs
Empty file.
4 changes: 2 additions & 2 deletions client/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ where
/// TODO IMPORTANT: see if we want this, or if we'd prefer `set_subscriptions()`
pub async fn subscribe(&self, topics: Vec<Topic>) -> Result<()> {
// Lock subscriptions here so we maintain parity during a reconnection
let mut subscribed_guard = self.0.inner.subscribed_topics.lock().await;
let mut subscribed_guard = self.0.inner.subscribed_topics.write().await;

// Calculate the real topics to send based on whatever's already in the set
let topics_to_send: Vec<Topic> = topics
Expand Down Expand Up @@ -163,7 +163,7 @@ where
/// If the connection or serialization has failed
pub async fn unsubscribe(&self, topics: Vec<Topic>) -> Result<()> {
// Lock subscriptions here so we maintain parity during a reconnection
let mut subscribed_guard = self.0.inner.subscribed_topics.lock().await;
let mut subscribed_guard = self.0.inner.subscribed_topics.write().await;

// Calculate the real topics to send based on whatever's already in the set
let topics_to_send: Vec<Topic> = topics
Expand Down
Loading

0 comments on commit 7cd3ac0

Please sign in to comment.