diff --git a/protocols/kad/src/handler.rs b/protocols/kad/src/handler.rs index 5be9ae17..a2efb906 100644 --- a/protocols/kad/src/handler.rs +++ b/protocols/kad/src/handler.rs @@ -24,6 +24,7 @@ use crate::protocol::{ }; use crate::record::{self, Record}; use futures::prelude::*; +use futures::stream::SelectAll; use instant::Instant; use libp2p_core::{ either::EitherOutput, @@ -35,6 +36,7 @@ use libp2p_swarm::{ KeepAlive, NegotiatedSubstream, SubstreamProtocol, }; use log::trace; +use std::task::Waker; use std::{ error, fmt, io, marker::PhantomData, pin::Pin, task::Context, task::Poll, time::Duration, }; @@ -56,7 +58,9 @@ impl KademliaHandlerProto { } } -impl IntoConnectionHandler for KademliaHandlerProto { +impl IntoConnectionHandler + for KademliaHandlerProto +{ type Handler = KademliaHandler; fn into_handler(self, remote_peer_id: &PeerId, endpoint: &ConnectedPoint) -> Self::Handler { @@ -87,10 +91,10 @@ pub struct KademliaHandler { next_connec_unique_id: UniqueConnecId, /// List of active outbound substreams with the state they are in. - outbound_substreams: Vec>, + outbound_substreams: SelectAll>, /// List of active inbound substreams with the state they are in. - inbound_substreams: Vec, + inbound_substreams: SelectAll>, /// Until when to keep the connection alive. keep_alive: KeepAlive, @@ -137,7 +141,7 @@ pub struct KademliaHandlerConfig { enum OutboundSubstreamState { /// We haven't started opening the outgoing substream yet. /// Contains the request we want to send, and the user data if we expect an answer. - PendingOpen(KadRequestMsg, Option), + PendingOpen(SubstreamProtocol)>), /// Waiting to send a message to the remote. PendingSend( KadOutStreamSink, @@ -153,10 +157,13 @@ enum OutboundSubstreamState { ReportError(KademliaHandlerQueryErr, TUserData), /// The substream is being closed. Closing(KadOutStreamSink), + /// The substream is complete and will not perform any more work. + Done, + Poisoned, } /// State of an active inbound substream. -enum InboundSubstreamState { +enum InboundSubstreamState { /// Waiting for a request from the remote. WaitingMessage { /// Whether it is the first message to be awaited on this stream. @@ -165,7 +172,11 @@ enum InboundSubstreamState { substream: KadInStreamSink, }, /// Waiting for the user to send a [`KademliaHandlerIn`] event containing the response. - WaitingUser(UniqueConnecId, KadInStreamSink), + WaitingUser( + UniqueConnecId, + KadInStreamSink, + Option, + ), /// Waiting to send an answer back to the remote. PendingSend( UniqueConnecId, @@ -176,48 +187,63 @@ enum InboundSubstreamState { PendingFlush(UniqueConnecId, KadInStreamSink), /// The substream is being closed. Closing(KadInStreamSink), + /// The substream was cancelled in favor of a new one. + Cancelled, + + Poisoned { + phantom: PhantomData, + }, } -impl OutboundSubstreamState { - /// Tries to close the substream. - /// - /// If the substream is not ready to be closed, returns it back. - fn try_close(&mut self, cx: &mut Context<'_>) -> Poll<()> { - match self { - OutboundSubstreamState::PendingOpen(_, _) - | OutboundSubstreamState::ReportError(_, _) => Poll::Ready(()), - OutboundSubstreamState::PendingSend(ref mut stream, _, _) - | OutboundSubstreamState::PendingFlush(ref mut stream, _) - | OutboundSubstreamState::WaitingAnswer(ref mut stream, _) - | OutboundSubstreamState::Closing(ref mut stream) => { - match Sink::poll_close(Pin::new(stream), cx) { - Poll::Ready(_) => Poll::Ready(()), - Poll::Pending => Poll::Pending, +impl InboundSubstreamState { + fn try_answer_with( + &mut self, + id: KademliaRequestId, + msg: KadResponseMsg, + ) -> Result<(), KadResponseMsg> { + match std::mem::replace( + self, + InboundSubstreamState::Poisoned { + phantom: PhantomData, + }, + ) { + InboundSubstreamState::WaitingUser(conn_id, substream, mut waker) + if conn_id == id.connec_unique_id => + { + *self = InboundSubstreamState::PendingSend(conn_id, substream, msg); + + if let Some(waker) = waker.take() { + waker.wake(); } + + Ok(()) + } + other => { + *self = other; + + Err(msg) } } } -} -impl InboundSubstreamState { - /// Tries to close the substream. - /// - /// If the substream is not ready to be closed, returns it back. - fn try_close(&mut self, cx: &mut Context<'_>) -> Poll<()> { - match self { - InboundSubstreamState::WaitingMessage { - substream: ref mut stream, - .. + fn close(&mut self) { + match std::mem::replace( + self, + InboundSubstreamState::Poisoned { + phantom: PhantomData, + }, + ) { + InboundSubstreamState::WaitingMessage { substream, .. } + | InboundSubstreamState::WaitingUser(_, substream, _) + | InboundSubstreamState::PendingSend(_, substream, _) + | InboundSubstreamState::PendingFlush(_, substream) + | InboundSubstreamState::Closing(substream) => { + *self = InboundSubstreamState::Closing(substream); } - | InboundSubstreamState::WaitingUser(_, ref mut stream) - | InboundSubstreamState::PendingSend(_, ref mut stream, _) - | InboundSubstreamState::PendingFlush(_, ref mut stream) - | InboundSubstreamState::Closing(ref mut stream) => { - match Sink::poll_close(Pin::new(stream), cx) { - Poll::Ready(_) => Poll::Ready(()), - Poll::Pending => Poll::Pending, - } + InboundSubstreamState::Cancelled => { + *self = InboundSubstreamState::Cancelled; } + InboundSubstreamState::Poisoned { .. } => unreachable!(), } } } @@ -469,7 +495,7 @@ pub enum KademliaHandlerIn { /// Unique identifier for a request. Must be passed back in order to answer a request from /// the remote. -#[derive(Debug, PartialEq, Eq)] +#[derive(Debug, PartialEq, Eq, Copy, Clone)] pub struct KademliaRequestId { /// Unique identifier for an incoming connection. connec_unique_id: UniqueConnecId, @@ -479,7 +505,10 @@ pub struct KademliaRequestId { #[derive(Debug, Copy, Clone, PartialEq, Eq)] struct UniqueConnecId(u64); -impl KademliaHandler { +impl KademliaHandler +where + TUserData: Unpin, +{ /// Create a [`KademliaHandler`] using the given configuration. pub fn new( config: KademliaHandlerConfig, @@ -503,7 +532,7 @@ impl KademliaHandler { impl ConnectionHandler for KademliaHandler where - TUserData: Clone + fmt::Debug + Send + 'static, + TUserData: Clone + fmt::Debug + Send + 'static + Unpin, { type InEvent = KademliaHandlerIn; type OutEvent = KademliaHandlerEvent; @@ -560,14 +589,14 @@ where } if self.inbound_substreams.len() == MAX_NUM_INBOUND_SUBSTREAMS { - if let Some(position) = self.inbound_substreams.iter().position(|s| { + if let Some(s) = self.inbound_substreams.iter_mut().find(|s| { matches!( s, // An inbound substream waiting to be reused. InboundSubstreamState::WaitingMessage { first: false, .. } ) }) { - self.inbound_substreams.remove(position); + *s = InboundSubstreamState::Cancelled; log::warn!( "New inbound substream to {:?} exceeds inbound substream limit. \ Removed older substream waiting to be reused.", @@ -597,153 +626,93 @@ where fn inject_event(&mut self, message: KademliaHandlerIn) { match message { KademliaHandlerIn::Reset(request_id) => { - let pos = self + if let Some(state) = self .inbound_substreams - .iter() - .position(|state| match state { - InboundSubstreamState::WaitingUser(conn_id, _) => { + .iter_mut() + .find(|state| match state { + InboundSubstreamState::WaitingUser(conn_id, _, _) => { conn_id == &request_id.connec_unique_id } _ => false, - }); - if let Some(pos) = pos { - // TODO: we don't properly close down the substream - let waker = futures::task::noop_waker(); - let mut cx = Context::from_waker(&waker); - let _ = self.inbound_substreams.remove(pos).try_close(&mut cx); + }) + { + state.close(); } } KademliaHandlerIn::FindNodeReq { key, user_data } => { let msg = KadRequestMsg::FindNode { key }; self.outbound_substreams - .push(OutboundSubstreamState::PendingOpen(msg, Some(user_data))); + .push(OutboundSubstreamState::PendingOpen(SubstreamProtocol::new( + self.config.protocol_config.clone(), + (msg, Some(user_data)), + ))); } KademliaHandlerIn::FindNodeRes { closer_peers, request_id, - } => { - let pos = self - .inbound_substreams - .iter() - .position(|state| match state { - InboundSubstreamState::WaitingUser(ref conn_id, _) => { - conn_id == &request_id.connec_unique_id - } - _ => false, - }); - - if let Some(pos) = pos { - let (conn_id, substream) = match self.inbound_substreams.remove(pos) { - InboundSubstreamState::WaitingUser(conn_id, substream) => { - (conn_id, substream) - } - _ => unreachable!(), - }; - - let msg = KadResponseMsg::FindNode { closer_peers }; - self.inbound_substreams - .push(InboundSubstreamState::PendingSend(conn_id, substream, msg)); - } - } + } => self.answer_pending_request(request_id, KadResponseMsg::FindNode { closer_peers }), KademliaHandlerIn::GetProvidersReq { key, user_data } => { let msg = KadRequestMsg::GetProviders { key }; self.outbound_substreams - .push(OutboundSubstreamState::PendingOpen(msg, Some(user_data))); + .push(OutboundSubstreamState::PendingOpen(SubstreamProtocol::new( + self.config.protocol_config.clone(), + (msg, Some(user_data)), + ))); } KademliaHandlerIn::GetProvidersRes { closer_peers, provider_peers, request_id, - } => { - let pos = self - .inbound_substreams - .iter() - .position(|state| matches!(state, InboundSubstreamState::WaitingUser(ref conn_id, _) if conn_id == &request_id.connec_unique_id)); - - if let Some(pos) = pos { - let (conn_id, substream) = match self.inbound_substreams.remove(pos) { - InboundSubstreamState::WaitingUser(conn_id, substream) => { - (conn_id, substream) - } - _ => unreachable!(), - }; - - let msg = KadResponseMsg::GetProviders { - closer_peers, - provider_peers, - }; - self.inbound_substreams - .push(InboundSubstreamState::PendingSend(conn_id, substream, msg)); - } - } + } => self.answer_pending_request( + request_id, + KadResponseMsg::GetProviders { + closer_peers, + provider_peers, + }, + ), KademliaHandlerIn::AddProvider { key, provider } => { let msg = KadRequestMsg::AddProvider { key, provider }; self.outbound_substreams - .push(OutboundSubstreamState::PendingOpen(msg, None)); + .push(OutboundSubstreamState::PendingOpen(SubstreamProtocol::new( + self.config.protocol_config.clone(), + (msg, None), + ))); } KademliaHandlerIn::GetRecord { key, user_data } => { let msg = KadRequestMsg::GetValue { key }; self.outbound_substreams - .push(OutboundSubstreamState::PendingOpen(msg, Some(user_data))); + .push(OutboundSubstreamState::PendingOpen(SubstreamProtocol::new( + self.config.protocol_config.clone(), + (msg, Some(user_data)), + ))); } KademliaHandlerIn::PutRecord { record, user_data } => { let msg = KadRequestMsg::PutValue { record }; self.outbound_substreams - .push(OutboundSubstreamState::PendingOpen(msg, Some(user_data))); + .push(OutboundSubstreamState::PendingOpen(SubstreamProtocol::new( + self.config.protocol_config.clone(), + (msg, Some(user_data)), + ))); } KademliaHandlerIn::GetRecordRes { record, closer_peers, request_id, } => { - let pos = self - .inbound_substreams - .iter() - .position(|state| match state { - InboundSubstreamState::WaitingUser(ref conn_id, _) => { - conn_id == &request_id.connec_unique_id - } - _ => false, - }); - - if let Some(pos) = pos { - let (conn_id, substream) = match self.inbound_substreams.remove(pos) { - InboundSubstreamState::WaitingUser(conn_id, substream) => { - (conn_id, substream) - } - _ => unreachable!(), - }; - - let msg = KadResponseMsg::GetValue { + self.answer_pending_request( + request_id, + KadResponseMsg::GetValue { record, closer_peers, - }; - self.inbound_substreams - .push(InboundSubstreamState::PendingSend(conn_id, substream, msg)); - } + }, + ); } KademliaHandlerIn::PutRecordRes { key, request_id, value, } => { - let pos = self - .inbound_substreams - .iter() - .position(|state| matches!(state, InboundSubstreamState::WaitingUser(ref conn_id, _) if conn_id == &request_id.connec_unique_id)); - - if let Some(pos) = pos { - let (conn_id, substream) = match self.inbound_substreams.remove(pos) { - InboundSubstreamState::WaitingUser(conn_id, substream) => { - (conn_id, substream) - } - _ => unreachable!(), - }; - - let msg = KadResponseMsg::PutValue { key, value }; - self.inbound_substreams - .push(InboundSubstreamState::PendingSend(conn_id, substream, msg)); - } + self.answer_pending_request(request_id, KadResponseMsg::PutValue { key, value }); } } } @@ -776,10 +745,6 @@ where Self::Error, >, > { - if self.outbound_substreams.is_empty() && self.inbound_substreams.is_empty() { - return Poll::Pending; - } - if let ProtocolStatus::Confirmed = self.protocol_status { self.protocol_status = ProtocolStatus::Reported; return Poll::Ready(ConnectionHandlerEvent::Custom( @@ -789,69 +754,12 @@ where )); } - // We remove each element from `outbound_substreams` one by one and add them back. - for n in (0..self.outbound_substreams.len()).rev() { - let mut substream = self.outbound_substreams.swap_remove(n); - - loop { - match advance_outbound_substream(substream, self.config.protocol_config.clone(), cx) - { - (Some(new_state), Some(event), _) => { - self.outbound_substreams.push(new_state); - return Poll::Ready(event); - } - (None, Some(event), _) => { - if self.outbound_substreams.is_empty() { - self.keep_alive = - KeepAlive::Until(Instant::now() + self.config.idle_timeout); - } - return Poll::Ready(event); - } - (Some(new_state), None, false) => { - self.outbound_substreams.push(new_state); - break; - } - (Some(new_state), None, true) => { - substream = new_state; - continue; - } - (None, None, _) => { - break; - } - } - } + if let Poll::Ready(Some(event)) = self.outbound_substreams.poll_next_unpin(cx) { + return Poll::Ready(event); } - // We remove each element from `inbound_substreams` one by one and add them back. - for n in (0..self.inbound_substreams.len()).rev() { - let mut substream = self.inbound_substreams.swap_remove(n); - - loop { - match advance_inbound_substream(substream, cx) { - (Some(new_state), Some(event), _) => { - self.inbound_substreams.push(new_state); - return Poll::Ready(event); - } - (None, Some(event), _) => { - if self.inbound_substreams.is_empty() { - self.keep_alive = - KeepAlive::Until(Instant::now() + self.config.idle_timeout); - } - return Poll::Ready(event); - } - (Some(new_state), None, false) => { - self.inbound_substreams.push(new_state); - break; - } - (Some(new_state), None, true) => { - substream = new_state; - continue; - } - (None, None, _) => { - break; - } - } - } + if let Poll::Ready(Some(event)) = self.inbound_substreams.poll_next_unpin(cx) { + return Poll::Ready(event); } if self.outbound_substreams.is_empty() && self.inbound_substreams.is_empty() { @@ -865,6 +773,24 @@ where } } +impl KademliaHandler +where + TUserData: 'static + Clone + Send + Unpin + fmt::Debug, +{ + fn answer_pending_request(&mut self, request_id: KademliaRequestId, mut msg: KadResponseMsg) { + for state in self.inbound_substreams.iter_mut() { + match state.try_answer_with(request_id, msg) { + Ok(()) => return, + Err(m) => { + msg = m; + } + } + } + + debug_assert!(false, "Cannot find inbound substream for {request_id:?}") + } +} + impl Default for KademliaHandlerConfig { fn default() -> Self { KademliaHandlerConfig { @@ -875,247 +801,245 @@ impl Default for KademliaHandlerConfig { } } -/// Advances one outbound substream. -/// -/// Returns the new state for that substream, an event to generate, and whether the substream -/// should be polled again. -fn advance_outbound_substream( - state: OutboundSubstreamState, - upgrade: KademliaProtocolConfig, - cx: &mut Context<'_>, -) -> ( - Option>, - Option< - ConnectionHandlerEvent< - KademliaProtocolConfig, - (KadRequestMsg, Option), - KademliaHandlerEvent, - io::Error, - >, - >, - bool, -) { - match state { - OutboundSubstreamState::PendingOpen(msg, user_data) => { - let ev = ConnectionHandlerEvent::OutboundSubstreamRequest { - protocol: SubstreamProtocol::new(upgrade, (msg, user_data)), - }; - (None, Some(ev), false) - } - OutboundSubstreamState::PendingSend(mut substream, msg, user_data) => { - match Sink::poll_ready(Pin::new(&mut substream), cx) { - Poll::Ready(Ok(())) => match Sink::start_send(Pin::new(&mut substream), msg) { - Ok(()) => ( - Some(OutboundSubstreamState::PendingFlush(substream, user_data)), - None, - true, - ), - Err(error) => { - let event = user_data.map(|user_data| { - ConnectionHandlerEvent::Custom(KademliaHandlerEvent::QueryError { +impl Stream for OutboundSubstreamState +where + TUserData: Unpin, +{ + type Item = ConnectionHandlerEvent< + KademliaProtocolConfig, + (KadRequestMsg, Option), + KademliaHandlerEvent, + io::Error, + >; + + fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let this = self.get_mut(); + + loop { + match std::mem::replace(this, OutboundSubstreamState::Poisoned) { + OutboundSubstreamState::PendingOpen(protocol) => { + *this = OutboundSubstreamState::Done; + return Poll::Ready(Some(ConnectionHandlerEvent::OutboundSubstreamRequest { + protocol, + })); + } + OutboundSubstreamState::PendingSend(mut substream, msg, user_data) => { + match substream.poll_ready_unpin(cx) { + Poll::Ready(Ok(())) => match substream.start_send_unpin(msg) { + Ok(()) => { + *this = OutboundSubstreamState::PendingFlush(substream, user_data); + } + Err(error) => { + *this = OutboundSubstreamState::Done; + let event = user_data.map(|user_data| { + ConnectionHandlerEvent::Custom( + KademliaHandlerEvent::QueryError { + error: KademliaHandlerQueryErr::Io(error), + user_data, + }, + ) + }); + + return Poll::Ready(event); + } + }, + Poll::Pending => { + *this = OutboundSubstreamState::PendingSend(substream, msg, user_data); + return Poll::Pending; + } + Poll::Ready(Err(error)) => { + *this = OutboundSubstreamState::Done; + let event = user_data.map(|user_data| { + ConnectionHandlerEvent::Custom(KademliaHandlerEvent::QueryError { + error: KademliaHandlerQueryErr::Io(error), + user_data, + }) + }); + + return Poll::Ready(event); + } + } + } + OutboundSubstreamState::PendingFlush(mut substream, user_data) => { + match substream.poll_flush_unpin(cx) { + Poll::Ready(Ok(())) => { + if let Some(user_data) = user_data { + *this = OutboundSubstreamState::WaitingAnswer(substream, user_data); + } else { + *this = OutboundSubstreamState::Closing(substream); + } + } + Poll::Pending => { + *this = OutboundSubstreamState::PendingFlush(substream, user_data); + return Poll::Pending; + } + Poll::Ready(Err(error)) => { + *this = OutboundSubstreamState::Done; + let event = user_data.map(|user_data| { + ConnectionHandlerEvent::Custom(KademliaHandlerEvent::QueryError { + error: KademliaHandlerQueryErr::Io(error), + user_data, + }) + }); + + return Poll::Ready(event); + } + } + } + OutboundSubstreamState::WaitingAnswer(mut substream, user_data) => { + match substream.poll_next_unpin(cx) { + Poll::Ready(Some(Ok(msg))) => { + *this = OutboundSubstreamState::Closing(substream); + let event = process_kad_response(msg, user_data); + + return Poll::Ready(Some(ConnectionHandlerEvent::Custom(event))); + } + Poll::Pending => { + *this = OutboundSubstreamState::WaitingAnswer(substream, user_data); + return Poll::Pending; + } + Poll::Ready(Some(Err(error))) => { + *this = OutboundSubstreamState::Done; + let event = KademliaHandlerEvent::QueryError { error: KademliaHandlerQueryErr::Io(error), user_data, - }) - }); + }; - (None, event, false) + return Poll::Ready(Some(ConnectionHandlerEvent::Custom(event))); + } + Poll::Ready(None) => { + *this = OutboundSubstreamState::Done; + let event = KademliaHandlerEvent::QueryError { + error: KademliaHandlerQueryErr::Io( + io::ErrorKind::UnexpectedEof.into(), + ), + user_data, + }; + + return Poll::Ready(Some(ConnectionHandlerEvent::Custom(event))); + } + } + } + OutboundSubstreamState::ReportError(error, user_data) => { + *this = OutboundSubstreamState::Done; + let event = KademliaHandlerEvent::QueryError { error, user_data }; + + return Poll::Ready(Some(ConnectionHandlerEvent::Custom(event))); + } + OutboundSubstreamState::Closing(mut stream) => match stream.poll_close_unpin(cx) { + Poll::Ready(Ok(())) | Poll::Ready(Err(_)) => return Poll::Ready(None), + Poll::Pending => { + *this = OutboundSubstreamState::Closing(stream); + return Poll::Pending; } }, - Poll::Pending => ( - Some(OutboundSubstreamState::PendingSend( - substream, msg, user_data, - )), - None, - false, - ), - Poll::Ready(Err(error)) => { - let event = user_data.map(|user_data| { - ConnectionHandlerEvent::Custom(KademliaHandlerEvent::QueryError { - error: KademliaHandlerQueryErr::Io(error), - user_data, - }) - }); - - (None, event, false) + OutboundSubstreamState::Done => { + *this = OutboundSubstreamState::Done; + return Poll::Ready(None); } - } - } - OutboundSubstreamState::PendingFlush(mut substream, user_data) => { - match Sink::poll_flush(Pin::new(&mut substream), cx) { - Poll::Ready(Ok(())) => { - if let Some(user_data) = user_data { - ( - Some(OutboundSubstreamState::WaitingAnswer(substream, user_data)), - None, - true, - ) - } else { - (Some(OutboundSubstreamState::Closing(substream)), None, true) - } - } - Poll::Pending => ( - Some(OutboundSubstreamState::PendingFlush(substream, user_data)), - None, - false, - ), - Poll::Ready(Err(error)) => { - let event = user_data.map(|user_data| { - ConnectionHandlerEvent::Custom(KademliaHandlerEvent::QueryError { - error: KademliaHandlerQueryErr::Io(error), - user_data, - }) - }); - - (None, event, false) - } - } - } - OutboundSubstreamState::WaitingAnswer(mut substream, user_data) => { - match Stream::poll_next(Pin::new(&mut substream), cx) { - Poll::Ready(Some(Ok(msg))) => { - let new_state = OutboundSubstreamState::Closing(substream); - let event = process_kad_response(msg, user_data); - ( - Some(new_state), - Some(ConnectionHandlerEvent::Custom(event)), - true, - ) - } - Poll::Pending => ( - Some(OutboundSubstreamState::WaitingAnswer(substream, user_data)), - None, - false, - ), - Poll::Ready(Some(Err(error))) => { - let event = KademliaHandlerEvent::QueryError { - error: KademliaHandlerQueryErr::Io(error), - user_data, - }; - (None, Some(ConnectionHandlerEvent::Custom(event)), false) - } - Poll::Ready(None) => { - let event = KademliaHandlerEvent::QueryError { - error: KademliaHandlerQueryErr::Io(io::ErrorKind::UnexpectedEof.into()), - user_data, - }; - (None, Some(ConnectionHandlerEvent::Custom(event)), false) - } - } - } - OutboundSubstreamState::ReportError(error, user_data) => { - let event = KademliaHandlerEvent::QueryError { error, user_data }; - (None, Some(ConnectionHandlerEvent::Custom(event)), false) - } - OutboundSubstreamState::Closing(mut stream) => { - match Sink::poll_close(Pin::new(&mut stream), cx) { - Poll::Ready(Ok(())) => (None, None, false), - Poll::Pending => (Some(OutboundSubstreamState::Closing(stream)), None, false), - Poll::Ready(Err(_)) => (None, None, false), + OutboundSubstreamState::Poisoned => unreachable!(), } } } } -/// Advances one inbound substream. -/// -/// Returns the new state for that substream, an event to generate, and whether the substream -/// should be polled again. -fn advance_inbound_substream( - state: InboundSubstreamState, - cx: &mut Context<'_>, -) -> ( - Option, - Option< - ConnectionHandlerEvent< - KademliaProtocolConfig, - (KadRequestMsg, Option), - KademliaHandlerEvent, - io::Error, - >, - >, - bool, -) { - match state { - InboundSubstreamState::WaitingMessage { - first, - connection_id, - mut substream, - } => match Stream::poll_next(Pin::new(&mut substream), cx) { - Poll::Ready(Some(Ok(msg))) => { - if let Ok(ev) = process_kad_request(msg, connection_id) { - ( - Some(InboundSubstreamState::WaitingUser(connection_id, substream)), - Some(ConnectionHandlerEvent::Custom(ev)), - false, - ) - } else { - (Some(InboundSubstreamState::Closing(substream)), None, true) - } - } - Poll::Pending => ( - Some(InboundSubstreamState::WaitingMessage { + +impl Stream for InboundSubstreamState +where + TUserData: Unpin, +{ + type Item = ConnectionHandlerEvent< + KademliaProtocolConfig, + (KadRequestMsg, Option), + KademliaHandlerEvent, + io::Error, + >; + + fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let this = self.get_mut(); + + loop { + match std::mem::replace( + this, + Self::Poisoned { + phantom: PhantomData, + }, + ) { + InboundSubstreamState::WaitingMessage { first, connection_id, - substream, - }), - None, - false, - ), - Poll::Ready(None) => { - trace!("Inbound substream: EOF"); - (None, None, false) - } - Poll::Ready(Some(Err(e))) => { - trace!("Inbound substream error: {:?}", e); - (None, None, false) - } - }, - InboundSubstreamState::WaitingUser(id, substream) => ( - Some(InboundSubstreamState::WaitingUser(id, substream)), - None, - false, - ), - InboundSubstreamState::PendingSend(id, mut substream, msg) => { - match Sink::poll_ready(Pin::new(&mut substream), cx) { - Poll::Ready(Ok(())) => match Sink::start_send(Pin::new(&mut substream), msg) { - Ok(()) => ( - Some(InboundSubstreamState::PendingFlush(id, substream)), - None, - true, - ), - Err(_) => (None, None, false), + mut substream, + } => match substream.poll_next_unpin(cx) { + Poll::Ready(Some(Ok(msg))) => { + if let Ok(ev) = process_kad_request(msg, connection_id) { + *this = + InboundSubstreamState::WaitingUser(connection_id, substream, None); + return Poll::Ready(Some(ConnectionHandlerEvent::Custom(ev))); + } else { + *this = InboundSubstreamState::Closing(substream); + } + } + Poll::Pending => { + *this = InboundSubstreamState::WaitingMessage { + first, + connection_id, + substream, + }; + return Poll::Pending; + } + Poll::Ready(None) => { + return Poll::Ready(None); + } + Poll::Ready(Some(Err(e))) => { + trace!("Inbound substream error: {:?}", e); + return Poll::Ready(None); + } }, - Poll::Pending => ( - Some(InboundSubstreamState::PendingSend(id, substream, msg)), - None, - false, - ), - Poll::Ready(Err(_)) => (None, None, false), - } - } - InboundSubstreamState::PendingFlush(id, mut substream) => { - match Sink::poll_flush(Pin::new(&mut substream), cx) { - Poll::Ready(Ok(())) => ( - Some(InboundSubstreamState::WaitingMessage { - first: false, - connection_id: id, - substream, - }), - None, - true, - ), - Poll::Pending => ( - Some(InboundSubstreamState::PendingFlush(id, substream)), - None, - false, - ), - Poll::Ready(Err(_)) => (None, None, false), - } - } - InboundSubstreamState::Closing(mut stream) => { - match Sink::poll_close(Pin::new(&mut stream), cx) { - Poll::Ready(Ok(())) => (None, None, false), - Poll::Pending => (Some(InboundSubstreamState::Closing(stream)), None, false), - Poll::Ready(Err(_)) => (None, None, false), + InboundSubstreamState::WaitingUser(id, substream, _) => { + *this = + InboundSubstreamState::WaitingUser(id, substream, Some(cx.waker().clone())); + + return Poll::Pending; + } + InboundSubstreamState::PendingSend(id, mut substream, msg) => { + match substream.poll_ready_unpin(cx) { + Poll::Ready(Ok(())) => match substream.start_send_unpin(msg) { + Ok(()) => { + *this = InboundSubstreamState::PendingFlush(id, substream); + } + Err(_) => return Poll::Ready(None), + }, + Poll::Pending => { + *this = InboundSubstreamState::PendingSend(id, substream, msg); + return Poll::Pending; + } + Poll::Ready(Err(_)) => return Poll::Ready(None), + } + } + InboundSubstreamState::PendingFlush(id, mut substream) => { + match substream.poll_flush_unpin(cx) { + Poll::Ready(Ok(())) => { + *this = InboundSubstreamState::WaitingMessage { + first: false, + connection_id: id, + substream, + }; + } + Poll::Pending => { + *this = InboundSubstreamState::PendingFlush(id, substream); + return Poll::Pending; + } + Poll::Ready(Err(_)) => return Poll::Ready(None), + } + } + InboundSubstreamState::Closing(mut stream) => match stream.poll_close_unpin(cx) { + Poll::Ready(Ok(())) | Poll::Ready(Err(_)) => return Poll::Ready(None), + Poll::Pending => { + *this = InboundSubstreamState::Closing(stream); + return Poll::Pending; + } + }, + InboundSubstreamState::Poisoned { .. } => unreachable!(), + InboundSubstreamState::Cancelled => return Poll::Ready(None), } } }