refactor(kad): don't use OutboundOpenInfo

As part of pushing #3268 forward, remove the use of `OutboundOpenInfo` from `libp2p-kad`.

Related #3268.

Pull-Request: #3760.
This commit is contained in:
Thomas Eizinger
2023-04-28 15:37:06 +02:00
committed by GitHub
parent 4ebb4d0a30
commit 99ad3b6eaf

View File

@ -67,8 +67,7 @@ pub struct KademliaHandler<TUserData> {
/// List of outbound substreams that are waiting to become active next. /// List of outbound substreams that are waiting to become active next.
/// Contains the request we want to send, and the user data if we expect an answer. /// Contains the request we want to send, and the user data if we expect an answer.
requested_streams: pending_messages: VecDeque<(KadRequestMsg, Option<TUserData>)>,
VecDeque<SubstreamProtocol<KademliaProtocolConfig, (KadRequestMsg, Option<TUserData>)>>,
/// List of active inbound substreams with the state they are in. /// List of active inbound substreams with the state they are in.
inbound_substreams: SelectAll<InboundSubstreamState<TUserData>>, inbound_substreams: SelectAll<InboundSubstreamState<TUserData>>,
@ -499,7 +498,7 @@ where
inbound_substreams: Default::default(), inbound_substreams: Default::default(),
outbound_substreams: Default::default(), outbound_substreams: Default::default(),
num_requested_outbound_streams: 0, num_requested_outbound_streams: 0,
requested_streams: Default::default(), pending_messages: Default::default(),
keep_alive, keep_alive,
protocol_status: ProtocolStatus::Unconfirmed, protocol_status: ProtocolStatus::Unconfirmed,
} }
@ -507,19 +506,22 @@ where
fn on_fully_negotiated_outbound( fn on_fully_negotiated_outbound(
&mut self, &mut self,
FullyNegotiatedOutbound { FullyNegotiatedOutbound { protocol, info: () }: FullyNegotiatedOutbound<
protocol,
info: (msg, user_data),
}: FullyNegotiatedOutbound<
<Self as ConnectionHandler>::OutboundProtocol, <Self as ConnectionHandler>::OutboundProtocol,
<Self as ConnectionHandler>::OutboundOpenInfo, <Self as ConnectionHandler>::OutboundOpenInfo,
>, >,
) { ) {
self.outbound_substreams if let Some((msg, user_data)) = self.pending_messages.pop_front() {
.push(OutboundSubstreamState::PendingSend( self.outbound_substreams
protocol, msg, user_data, .push(OutboundSubstreamState::PendingSend(
)); protocol, msg, user_data,
));
} else {
debug_assert!(false, "Requested outbound stream without message")
}
self.num_requested_outbound_streams -= 1; self.num_requested_outbound_streams -= 1;
if let ProtocolStatus::Unconfirmed = self.protocol_status { if let ProtocolStatus::Unconfirmed = self.protocol_status {
// Upon the first successfully negotiated substream, we know that the // Upon the first successfully negotiated substream, we know that the
// remote is configured with the same protocol name and we want // remote is configured with the same protocol name and we want
@ -587,9 +589,7 @@ where
fn on_dial_upgrade_error( fn on_dial_upgrade_error(
&mut self, &mut self,
DialUpgradeError { DialUpgradeError {
info: (_, user_data), info: (), error, ..
error,
..
}: DialUpgradeError< }: DialUpgradeError<
<Self as ConnectionHandler>::OutboundOpenInfo, <Self as ConnectionHandler>::OutboundOpenInfo,
<Self as ConnectionHandler>::OutboundProtocol, <Self as ConnectionHandler>::OutboundProtocol,
@ -597,10 +597,12 @@ where
) { ) {
// TODO: cache the fact that the remote doesn't support kademlia at all, so that we don't // TODO: cache the fact that the remote doesn't support kademlia at all, so that we don't
// continue trying // continue trying
if let Some(user_data) = user_data {
if let Some((_, Some(user_data))) = self.pending_messages.pop_front() {
self.outbound_substreams self.outbound_substreams
.push(OutboundSubstreamState::ReportError(error.into(), user_data)); .push(OutboundSubstreamState::ReportError(error.into(), user_data));
} }
self.num_requested_outbound_streams -= 1; self.num_requested_outbound_streams -= 1;
} }
} }
@ -614,8 +616,7 @@ where
type Error = io::Error; // TODO: better error type? type Error = io::Error; // TODO: better error type?
type InboundProtocol = Either<KademliaProtocolConfig, upgrade::DeniedUpgrade>; type InboundProtocol = Either<KademliaProtocolConfig, upgrade::DeniedUpgrade>;
type OutboundProtocol = KademliaProtocolConfig; type OutboundProtocol = KademliaProtocolConfig;
// Message of the request to send to the remote, and user data if we expect an answer. type OutboundOpenInfo = ();
type OutboundOpenInfo = (KadRequestMsg, Option<TUserData>);
type InboundOpenInfo = (); type InboundOpenInfo = ();
fn listen_protocol(&self) -> SubstreamProtocol<Self::InboundProtocol, Self::InboundOpenInfo> { fn listen_protocol(&self) -> SubstreamProtocol<Self::InboundProtocol, Self::InboundOpenInfo> {
@ -645,10 +646,7 @@ where
} }
KademliaHandlerIn::FindNodeReq { key, user_data } => { KademliaHandlerIn::FindNodeReq { key, user_data } => {
let msg = KadRequestMsg::FindNode { key }; let msg = KadRequestMsg::FindNode { key };
self.requested_streams.push_back(SubstreamProtocol::new( self.pending_messages.push_back((msg, Some(user_data)));
self.config.protocol_config.clone(),
(msg, Some(user_data)),
));
} }
KademliaHandlerIn::FindNodeRes { KademliaHandlerIn::FindNodeRes {
closer_peers, closer_peers,
@ -656,10 +654,7 @@ where
} => self.answer_pending_request(request_id, KadResponseMsg::FindNode { closer_peers }), } => self.answer_pending_request(request_id, KadResponseMsg::FindNode { closer_peers }),
KademliaHandlerIn::GetProvidersReq { key, user_data } => { KademliaHandlerIn::GetProvidersReq { key, user_data } => {
let msg = KadRequestMsg::GetProviders { key }; let msg = KadRequestMsg::GetProviders { key };
self.requested_streams.push_back(SubstreamProtocol::new( self.pending_messages.push_back((msg, Some(user_data)));
self.config.protocol_config.clone(),
(msg, Some(user_data)),
));
} }
KademliaHandlerIn::GetProvidersRes { KademliaHandlerIn::GetProvidersRes {
closer_peers, closer_peers,
@ -674,24 +669,15 @@ where
), ),
KademliaHandlerIn::AddProvider { key, provider } => { KademliaHandlerIn::AddProvider { key, provider } => {
let msg = KadRequestMsg::AddProvider { key, provider }; let msg = KadRequestMsg::AddProvider { key, provider };
self.requested_streams.push_back(SubstreamProtocol::new( self.pending_messages.push_back((msg, None));
self.config.protocol_config.clone(),
(msg, None),
));
} }
KademliaHandlerIn::GetRecord { key, user_data } => { KademliaHandlerIn::GetRecord { key, user_data } => {
let msg = KadRequestMsg::GetValue { key }; let msg = KadRequestMsg::GetValue { key };
self.requested_streams.push_back(SubstreamProtocol::new( self.pending_messages.push_back((msg, Some(user_data)));
self.config.protocol_config.clone(),
(msg, Some(user_data)),
));
} }
KademliaHandlerIn::PutRecord { record, user_data } => { KademliaHandlerIn::PutRecord { record, user_data } => {
let msg = KadRequestMsg::PutValue { record }; let msg = KadRequestMsg::PutValue { record };
self.requested_streams.push_back(SubstreamProtocol::new( self.pending_messages.push_back((msg, Some(user_data)));
self.config.protocol_config.clone(),
(msg, Some(user_data)),
));
} }
KademliaHandlerIn::GetRecordRes { KademliaHandlerIn::GetRecordRes {
record, record,
@ -750,11 +736,13 @@ where
let num_in_progress_outbound_substreams = let num_in_progress_outbound_substreams =
self.outbound_substreams.len() + self.num_requested_outbound_streams; self.outbound_substreams.len() + self.num_requested_outbound_streams;
if num_in_progress_outbound_substreams < MAX_NUM_SUBSTREAMS { if num_in_progress_outbound_substreams < MAX_NUM_SUBSTREAMS
if let Some(protocol) = self.requested_streams.pop_front() { && self.num_requested_outbound_streams < self.pending_messages.len()
self.num_requested_outbound_streams += 1; {
return Poll::Ready(ConnectionHandlerEvent::OutboundSubstreamRequest { protocol }); self.num_requested_outbound_streams += 1;
} return Poll::Ready(ConnectionHandlerEvent::OutboundSubstreamRequest {
protocol: SubstreamProtocol::new(self.config.protocol_config.clone(), ()),
});
} }
let no_streams = self.outbound_substreams.is_empty() && self.inbound_substreams.is_empty(); let no_streams = self.outbound_substreams.is_empty() && self.inbound_substreams.is_empty();
@ -828,7 +816,7 @@ where
{ {
type Item = ConnectionHandlerEvent< type Item = ConnectionHandlerEvent<
KademliaProtocolConfig, KademliaProtocolConfig,
(KadRequestMsg, Option<TUserData>), (),
KademliaHandlerEvent<TUserData>, KademliaHandlerEvent<TUserData>,
io::Error, io::Error,
>; >;
@ -964,7 +952,7 @@ where
{ {
type Item = ConnectionHandlerEvent< type Item = ConnectionHandlerEvent<
KademliaProtocolConfig, KademliaProtocolConfig,
(KadRequestMsg, Option<TUserData>), (),
KademliaHandlerEvent<TUserData>, KademliaHandlerEvent<TUserData>,
io::Error, io::Error,
>; >;