Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor(kad): don't use OutboundOpenInfo #3760

Merged
merged 7 commits into from
Apr 28, 2023
76 changes: 32 additions & 44 deletions protocols/kad/src/handler_priv.rs
Original file line number Diff line number Diff line change
Expand Up @@ -67,8 +67,7 @@ pub struct KademliaHandler<TUserData> {

/// List of outbound substreams that are waiting to become active next.
/// Contains the request we want to send, and the user data if we expect an answer.
requested_streams:
VecDeque<SubstreamProtocol<KademliaProtocolConfig, (KadRequestMsg, Option<TUserData>)>>,
pending_messages: VecDeque<(KadRequestMsg, Option<TUserData>)>,

/// List of active inbound substreams with the state they are in.
inbound_substreams: SelectAll<InboundSubstreamState<TUserData>>,
Expand Down Expand Up @@ -499,27 +498,30 @@ where
inbound_substreams: Default::default(),
outbound_substreams: Default::default(),
num_requested_outbound_streams: 0,
requested_streams: Default::default(),
pending_messages: Default::default(),
keep_alive,
protocol_status: ProtocolStatus::Unconfirmed,
}
}

fn on_fully_negotiated_outbound(
&mut self,
FullyNegotiatedOutbound {
protocol,
info: (msg, user_data),
}: FullyNegotiatedOutbound<
FullyNegotiatedOutbound { protocol, info: () }: FullyNegotiatedOutbound<
<Self as ConnectionHandler>::OutboundProtocol,
<Self as ConnectionHandler>::OutboundOpenInfo,
>,
) {
self.outbound_substreams
.push(OutboundSubstreamState::PendingSend(
protocol, msg, user_data,
));
if let Some((msg, user_data)) = self.pending_messages.pop_front() {
self.outbound_substreams
.push(OutboundSubstreamState::PendingSend(
protocol, msg, user_data,
));
} else {
debug_assert!(false, "Requested outbound stream without message")
}

self.num_requested_outbound_streams -= 1;

if let ProtocolStatus::Unconfirmed = self.protocol_status {
// Upon the first successfully negotiated substream, we know that the
// remote is configured with the same protocol name and we want
Expand Down Expand Up @@ -587,20 +589,20 @@ where
fn on_dial_upgrade_error(
&mut self,
DialUpgradeError {
info: (_, user_data),
error,
..
info: (), error, ..
}: DialUpgradeError<
<Self as ConnectionHandler>::OutboundOpenInfo,
<Self as ConnectionHandler>::OutboundProtocol,
>,
) {
// TODO: cache the fact that the remote doesn't support kademlia at all, so that we don't
// continue trying
if let Some(user_data) = user_data {

if let Some((_, Some(user_data))) = self.pending_messages.pop_front() {
self.outbound_substreams
.push(OutboundSubstreamState::ReportError(error.into(), user_data));
}

self.num_requested_outbound_streams -= 1;
}
}
Expand All @@ -614,8 +616,7 @@ where
type Error = io::Error; // TODO: better error type?
type InboundProtocol = Either<KademliaProtocolConfig, upgrade::DeniedUpgrade>;
type OutboundProtocol = KademliaProtocolConfig;
// Message of the request to send to the remote, and user data if we expect an answer.
type OutboundOpenInfo = (KadRequestMsg, Option<TUserData>);
type OutboundOpenInfo = ();
type InboundOpenInfo = ();

fn listen_protocol(&self) -> SubstreamProtocol<Self::InboundProtocol, Self::InboundOpenInfo> {
Expand Down Expand Up @@ -645,21 +646,15 @@ where
}
KademliaHandlerIn::FindNodeReq { key, user_data } => {
let msg = KadRequestMsg::FindNode { key };
self.requested_streams.push_back(SubstreamProtocol::new(
self.config.protocol_config.clone(),
(msg, Some(user_data)),
));
self.pending_messages.push_back((msg, Some(user_data)));
}
KademliaHandlerIn::FindNodeRes {
closer_peers,
request_id,
} => self.answer_pending_request(request_id, KadResponseMsg::FindNode { closer_peers }),
KademliaHandlerIn::GetProvidersReq { key, user_data } => {
let msg = KadRequestMsg::GetProviders { key };
self.requested_streams.push_back(SubstreamProtocol::new(
self.config.protocol_config.clone(),
(msg, Some(user_data)),
));
self.pending_messages.push_back((msg, Some(user_data)));
}
KademliaHandlerIn::GetProvidersRes {
closer_peers,
Expand All @@ -674,24 +669,15 @@ where
),
KademliaHandlerIn::AddProvider { key, provider } => {
let msg = KadRequestMsg::AddProvider { key, provider };
self.requested_streams.push_back(SubstreamProtocol::new(
self.config.protocol_config.clone(),
(msg, None),
));
self.pending_messages.push_back((msg, None));
}
KademliaHandlerIn::GetRecord { key, user_data } => {
let msg = KadRequestMsg::GetValue { key };
self.requested_streams.push_back(SubstreamProtocol::new(
self.config.protocol_config.clone(),
(msg, Some(user_data)),
));
self.pending_messages.push_back((msg, Some(user_data)));
}
KademliaHandlerIn::PutRecord { record, user_data } => {
let msg = KadRequestMsg::PutValue { record };
self.requested_streams.push_back(SubstreamProtocol::new(
self.config.protocol_config.clone(),
(msg, Some(user_data)),
));
self.pending_messages.push_back((msg, Some(user_data)));
}
KademliaHandlerIn::GetRecordRes {
record,
Expand Down Expand Up @@ -750,11 +736,13 @@ where

let num_in_progress_outbound_substreams =
self.outbound_substreams.len() + self.num_requested_outbound_streams;
if num_in_progress_outbound_substreams < MAX_NUM_SUBSTREAMS {
if let Some(protocol) = self.requested_streams.pop_front() {
self.num_requested_outbound_streams += 1;
return Poll::Ready(ConnectionHandlerEvent::OutboundSubstreamRequest { protocol });
}
if num_in_progress_outbound_substreams < MAX_NUM_SUBSTREAMS
&& self.num_requested_outbound_streams < self.pending_messages.len()
{
self.num_requested_outbound_streams += 1;
return Poll::Ready(ConnectionHandlerEvent::OutboundSubstreamRequest {
protocol: SubstreamProtocol::new(self.config.protocol_config.clone(), ()),
});
}

let no_streams = self.outbound_substreams.is_empty() && self.inbound_substreams.is_empty();
Expand Down Expand Up @@ -828,7 +816,7 @@ where
{
type Item = ConnectionHandlerEvent<
KademliaProtocolConfig,
(KadRequestMsg, Option<TUserData>),
(),
KademliaHandlerEvent<TUserData>,
io::Error,
>;
Expand Down Expand Up @@ -964,7 +952,7 @@ where
{
type Item = ConnectionHandlerEvent<
KademliaProtocolConfig,
(KadRequestMsg, Option<TUserData>),
(),
KademliaHandlerEvent<TUserData>,
io::Error,
>;
Expand Down