Skip to content

Commit

Permalink
feat: add udp packet ratelimiting (#8406)
Browse files Browse the repository at this point in the history
Co-authored-by: Federico Gimenez <[email protected]>
  • Loading branch information
mattsse and fgimenez authored May 27, 2024
1 parent ed926ec commit 2e47e9f
Show file tree
Hide file tree
Showing 4 changed files with 92 additions and 2 deletions.
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions crates/net/discv4/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ tokio = { workspace = true, features = ["io-util", "net", "time"] }
tokio-stream.workspace = true

# misc
schnellru.workspace = true
tracing.workspace = true
thiserror.workspace = true
parking_lot.workspace = true
Expand Down
90 changes: 89 additions & 1 deletion crates/net/discv4/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,9 @@ use secp256k1::SecretKey;
use std::{
cell::RefCell,
collections::{btree_map, hash_map::Entry, BTreeMap, HashMap, VecDeque},
fmt, io,
fmt,
future::poll_fn,
io,
net::{IpAddr, Ipv4Addr, SocketAddr, SocketAddrV4},
pin::Pin,
rc::Rc,
Expand Down Expand Up @@ -1796,7 +1798,13 @@ pub(crate) async fn send_loop(udp: Arc<UdpSocket>, rx: EgressReceiver) {
}
}

/// Rate limits the number of incoming packets from individual IPs to 1 packet/second
const MAX_INCOMING_PACKETS_PER_MINUTE_BY_IP: usize = 60usize;

/// Continuously awaits new incoming messages and sends them back through the channel.
///
/// The receive loop enforce primitive rate limiting for ips to prevent message spams from
/// individual IPs
pub(crate) async fn receive_loop(udp: Arc<UdpSocket>, tx: IngressSender, local_id: PeerId) {
let send = |event: IngressEvent| async {
let _ = tx.send(event).await.map_err(|err| {
Expand All @@ -1808,6 +1816,12 @@ pub(crate) async fn receive_loop(udp: Arc<UdpSocket>, tx: IngressSender, local_i
});
};

let mut cache = ReceiveCache::default();

// tick at half the rate of the limit
let tick = MAX_INCOMING_PACKETS_PER_MINUTE_BY_IP / 2;
let mut interval = tokio::time::interval(Duration::from_secs(tick as u64));

let mut buf = [0; MAX_PACKET_SIZE];
loop {
let res = udp.recv_from(&mut buf).await;
Expand All @@ -1817,6 +1831,12 @@ pub(crate) async fn receive_loop(udp: Arc<UdpSocket>, tx: IngressSender, local_i
send(IngressEvent::RecvError(err)).await;
}
Ok((read, remote_addr)) => {
// rate limit incoming packets by IP
if cache.inc_ip(remote_addr.ip()) > MAX_INCOMING_PACKETS_PER_MINUTE_BY_IP {
trace!(target: "discv4", ?remote_addr, "Too many incoming packets from IP.");
continue
}

let packet = &buf[..read];
match Message::decode(packet) {
Ok(packet) => {
Expand All @@ -1825,6 +1845,13 @@ pub(crate) async fn receive_loop(udp: Arc<UdpSocket>, tx: IngressSender, local_i
debug!(target: "discv4", ?remote_addr, "Received own packet.");
continue
}

// skip if we've already received the same packet
if cache.contains_packet(packet.hash) {
debug!(target: "discv4", ?remote_addr, "Received duplicate packet.");
continue
}

send(IngressEvent::Packet(remote_addr, packet)).await;
}
Err(err) => {
Expand All @@ -1834,6 +1861,67 @@ pub(crate) async fn receive_loop(udp: Arc<UdpSocket>, tx: IngressSender, local_i
}
}
}

// reset the tracked ips if the interval has passed
if poll_fn(|cx| match interval.poll_tick(cx) {
Poll::Ready(_) => Poll::Ready(true),
Poll::Pending => Poll::Ready(false),
})
.await
{
cache.tick_ips(tick);
}
}
}

/// A cache for received packets and their source address.
///
/// This is used to discard duplicated packets and rate limit messages from the same source.
struct ReceiveCache {
/// keeps track of how many messages we've received from a given IP address since the last
/// tick.
///
/// This is used to count the number of messages received from a given IP address within an
/// interval.
ip_messages: HashMap<IpAddr, usize>,
// keeps track of unique packet hashes
unique_packets: schnellru::LruMap<B256, ()>,
}

impl ReceiveCache {
/// Updates the counter for each IP address and removes IPs that have exceeded the limit.
///
/// This will decrement the counter for each IP address and remove IPs that have reached 0.
fn tick_ips(&mut self, tick: usize) {
self.ip_messages.retain(|_, count| {
if let Some(reset) = count.checked_sub(tick) {
*count = reset;
true
} else {
false
}
});
}

/// Increases the counter for the given IP address and returns the new count.
fn inc_ip(&mut self, ip: IpAddr) -> usize {
let ctn = self.ip_messages.entry(ip).or_default();
*ctn = ctn.saturating_add(1);
*ctn
}

/// Returns true if we previously received the packet
fn contains_packet(&mut self, hash: B256) -> bool {
!self.unique_packets.insert(hash, ())
}
}

impl Default for ReceiveCache {
fn default() -> Self {
Self {
ip_messages: Default::default(),
unique_packets: schnellru::LruMap::new(schnellru::ByLength::new(32)),
}
}
}

Expand Down
2 changes: 1 addition & 1 deletion crates/net/discv4/src/proto.rs
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ pub enum Message {

impl Message {
/// Returns the id for this type
pub fn msg_type(&self) -> MessageId {
pub const fn msg_type(&self) -> MessageId {
match self {
Message::Ping(_) => MessageId::Ping,
Message::Pong(_) => MessageId::Pong,
Expand Down

0 comments on commit 2e47e9f

Please sign in to comment.