From c02dea8128df7d2b9c724882ba9bdaa64674ddbb Mon Sep 17 00:00:00 2001 From: Toralf Wittner Date: Thu, 30 Aug 2018 23:25:16 +0200 Subject: [PATCH] Introduce several concrete future types. (#433) * multisteam-select: introduce `DialerFuture`. * multistream-select: add more concrete futures. * multistream-select: add ListenerFuture. * multistream-select: add ListenerSelectFuture * Formatting. * Add DialerSelectFuture type alias. * Add UpgradeApplyFuture and NegotiationFuture. * In iterator wrappers also pass-through size_hint. * Minor refactoring. * Address review comments. * Add some comments. * Hide state enums in wrapping structs. --- core/src/upgrade/apply.rs | 190 ++++++-- misc/multistream-select/src/dialer_select.rs | 425 +++++++++++++----- misc/multistream-select/src/lib.rs | 5 +- .../multistream-select/src/listener_select.rs | 166 ++++--- .../multistream-select/src/protocol/dialer.rs | 38 +- .../src/protocol/listener.rs | 86 +++- misc/multistream-select/src/protocol/mod.rs | 4 +- 7 files changed, 665 insertions(+), 249 deletions(-) diff --git a/core/src/upgrade/apply.rs b/core/src/upgrade/apply.rs index e6ff6732..7389f70e 100644 --- a/core/src/upgrade/apply.rs +++ b/core/src/upgrade/apply.rs @@ -19,9 +19,9 @@ // DEALINGS IN THE SOFTWARE. use bytes::Bytes; -use futures::{prelude::*, future}; -use multistream_select; -use std::io::{Error as IoError, ErrorKind as IoErrorKind}; +use futures::{prelude::*, future::Either}; +use multistream_select::{self, DialerSelectFuture, ListenerSelectFuture}; +use std::{io::{Error as IoError, ErrorKind as IoErrorKind}, mem}; use tokio_io::{AsyncRead, AsyncWrite}; use upgrade::{ConnectionUpgrade, Endpoint}; @@ -29,31 +29,97 @@ use upgrade::{ConnectionUpgrade, Endpoint}; /// /// Returns a `Future` that returns the outcome of the connection upgrade. #[inline] -pub fn apply( - connection: C, - upgrade: U, - endpoint: Endpoint, - remote_addr: Maf, -) -> impl Future +pub fn apply(conn: C, upgrade: U, e: Endpoint, remote: Maf) -> UpgradeApplyFuture where U: ConnectionUpgrade, U::NamesIter: Clone, // TODO: not elegant C: AsyncRead + AsyncWrite, { - negotiate(connection, &upgrade, endpoint) - .and_then(move |(upgrade_id, connection)| { - upgrade.upgrade(connection, upgrade_id, endpoint, remote_addr) - }) - .into_future() - .then(|val| { - match val { - Ok(_) => debug!("Successfully applied negotiated protocol"), - Err(ref err) => debug!("Failed to apply negotiated protocol: {:?}", err), - } - val - }) + UpgradeApplyFuture { + inner: UpgradeApplyState::Init { + future: negotiate(conn, &upgrade, e), + upgrade, + endpoint: e, + remote + } + } } +/// Future, returned from `apply` which performs a connection upgrade. +pub struct UpgradeApplyFuture +where + U: ConnectionUpgrade, + C: AsyncRead + AsyncWrite +{ + inner: UpgradeApplyState +} + + +enum UpgradeApplyState +where + U: ConnectionUpgrade, + C: AsyncRead + AsyncWrite +{ + Init { + future: NegotiationFuture, U::UpgradeIdentifier>, + upgrade: U, + endpoint: Endpoint, + remote: Maf + }, + Upgrade { + future: U::Future + }, + Undefined +} + +impl Future for UpgradeApplyFuture +where + U: ConnectionUpgrade, + U::NamesIter: Clone, + C: AsyncRead + AsyncWrite +{ + type Item = (U::Output, U::MultiaddrFuture); + type Error = IoError; + + fn poll(&mut self) -> Poll { + loop { + match mem::replace(&mut self.inner, UpgradeApplyState::Undefined) { + UpgradeApplyState::Init { mut future, upgrade, endpoint, remote } => { + let (upgrade_id, connection) = match future.poll()? { + Async::Ready(x) => x, + Async::NotReady => { + self.inner = UpgradeApplyState::Init { future, upgrade, endpoint, remote }; + return Ok(Async::NotReady) + } + }; + self.inner = UpgradeApplyState::Upgrade { + future: upgrade.upgrade(connection, upgrade_id, endpoint, remote) + }; + } + UpgradeApplyState::Upgrade { mut future } => { + match future.poll() { + Ok(Async::NotReady) => { + self.inner = UpgradeApplyState::Upgrade { future }; + return Ok(Async::NotReady) + } + Ok(Async::Ready(x)) => { + debug!("Successfully applied negotiated protocol"); + return Ok(Async::Ready(x)) + } + Err(e) => { + debug!("Failed to apply negotiated protocol: {:?}", e); + return Err(e) + } + } + } + UpgradeApplyState::Undefined => + panic!("UpgradeApplyState::poll called after completion") + } + } + } +} + + /// Negotiates a protocol on a stream. /// /// Returns a `Future` that returns the negotiated protocol and the stream. @@ -62,29 +128,73 @@ pub fn negotiate( connection: C, upgrade: &U, endpoint: Endpoint, -) -> impl Future +) -> NegotiationFuture, U::UpgradeIdentifier> where U: ConnectionUpgrade, U::NamesIter: Clone, // TODO: not elegant C: AsyncRead + AsyncWrite, { - let iter = upgrade - .protocol_names() - .map::<_, fn(_) -> _>(|(n, t)| (n, ::eq, t)); debug!("Starting protocol negotiation"); - - let negotiation = match endpoint { - Endpoint::Listener => future::Either::A(multistream_select::listener_select_proto(connection, iter)), - Endpoint::Dialer => future::Either::B(multistream_select::dialer_select_proto(connection, iter)), - }; - - negotiation - .map_err(|err| IoError::new(IoErrorKind::Other, err)) - .then(move |negotiated| { - match negotiated { - Ok(_) => debug!("Successfully negotiated protocol upgrade"), - Err(ref err) => debug!("Error while negotiated protocol upgrade: {:?}", err), - }; - negotiated - }) + let iter = ProtocolNames(upgrade.protocol_names()); + NegotiationFuture { + inner: match endpoint { + Endpoint::Listener => Either::A(multistream_select::listener_select_proto(connection, iter)), + Endpoint::Dialer => Either::B(multistream_select::dialer_select_proto(connection, iter)), + } + } } + + +/// Future, returned by `negotiate`, which negotiates a protocol and stream. +pub struct NegotiationFuture { + inner: Either, DialerSelectFuture> +} + +impl Future for NegotiationFuture +where + R: AsyncRead + AsyncWrite, + I: Iterator + Clone, + M: FnMut(&Bytes, &Bytes) -> bool, +{ + type Item = (P, R); + type Error = IoError; + + fn poll(&mut self) -> Poll { + match self.inner.poll() { + Ok(Async::NotReady) => Ok(Async::NotReady), + Ok(Async::Ready(x)) => { + debug!("Successfully negotiated protocol upgrade"); + Ok(Async::Ready(x)) + } + Err(e) => { + let err = IoError::new(IoErrorKind::Other, e); + debug!("Error while negotiated protocol upgrade: {:?}", err); + Err(err) + } + } + } +} + + +/// Iterator adapter which adds equality matching predicates to items. +/// Used in `NegotiationFuture`. +#[derive(Clone)] +pub struct ProtocolNames(I); + +impl Iterator for ProtocolNames +where + I: Iterator +{ + type Item = (Bytes, fn(&Bytes, &Bytes) -> bool, Id); + + fn next(&mut self) -> Option { + let f = ::eq as fn(&Bytes, &Bytes) -> bool; + self.0.next().map(|(b, id)| (b, f, id)) + } + + fn size_hint(&self) -> (usize, Option) { + self.0.size_hint() + } +} + + diff --git a/misc/multistream-select/src/dialer_select.rs b/misc/multistream-select/src/dialer_select.rs index 986c1d38..4d3b230f 100644 --- a/misc/multistream-select/src/dialer_select.rs +++ b/misc/multistream-select/src/dialer_select.rs @@ -22,14 +22,16 @@ //! `multistream-select` for the dialer. use bytes::Bytes; -use futures::future::{loop_fn, result, Loop, Either}; -use futures::{Future, Sink, Stream}; +use futures::{future::Either, prelude::*, sink, stream::StreamFuture}; +use protocol::{Dialer, DialerFuture, DialerToListenerMessage, ListenerToDialerMessage}; +use std::mem; +use tokio_io::{AsyncRead, AsyncWrite}; use ProtocolChoiceError; -use protocol::Dialer; -use protocol::DialerToListenerMessage; -use protocol::ListenerToDialerMessage; -use tokio_io::{AsyncRead, AsyncWrite}; +/// Future, returned by `dialer_select_proto`, which selects a protocol and dialer +/// either sequentially of by considering all protocols in parallel. +pub type DialerSelectFuture = + Either, P>, DialerSelectPar>; /// Helps selecting a protocol amongst the ones supported. /// @@ -43,148 +45,337 @@ use tokio_io::{AsyncRead, AsyncWrite}; /// success, the function returns the identifier (of type `P`), plus the socket which now uses that /// chosen protocol. #[inline] -pub fn dialer_select_proto( - inner: R, - protocols: I, -) -> impl Future +pub fn dialer_select_proto(inner: R, protocols: I) -> DialerSelectFuture where R: AsyncRead + AsyncWrite, - I: Iterator, + I: Iterator, M: FnMut(&Bytes, &Bytes) -> bool, { // We choose between the "serial" and "parallel" strategies based on the number of protocols. if protocols.size_hint().1.map(|n| n <= 3).unwrap_or(false) { - let fut = dialer_select_proto_serial(inner, protocols.map(|(n, _, id)| (n, id))); - Either::A(fut) + Either::A(dialer_select_proto_serial(inner, IgnoreMatchFn(protocols))) } else { - let fut = dialer_select_proto_parallel(inner, protocols); - Either::B(fut) + Either::B(dialer_select_proto_parallel(inner, protocols)) } } + +/// Iterator, which ignores match predicates of the iterator it wraps. +pub struct IgnoreMatchFn(I); + +impl Iterator for IgnoreMatchFn +where + I: Iterator +{ + type Item = (Bytes, P); + + fn next(&mut self) -> Option { + self.0.next().map(|(b, _, p)| (b, p)) + } + + fn size_hint(&self) -> (usize, Option) { + self.0.size_hint() + } +} + + /// Helps selecting a protocol amongst the ones supported. /// /// Same as `dialer_select_proto`. Tries protocols one by one. The iterator doesn't need to produce /// match functions, because it's not needed. -pub fn dialer_select_proto_serial( - inner: R, - mut protocols: I, -) -> impl Future +pub fn dialer_select_proto_serial(inner: R, protocols: I,) -> DialerSelectSeq where R: AsyncRead + AsyncWrite, I: Iterator, { - Dialer::new(inner).from_err().and_then(move |dialer| { - // Similar to a `loop` keyword. - loop_fn(dialer, move |dialer| { - result(protocols.next().ok_or(ProtocolChoiceError::NoProtocolFound)) - // If the `protocols` iterator produced an element, send it to the dialer - .and_then(|(proto_name, proto_value)| { - let req = DialerToListenerMessage::ProtocolRequest { - name: proto_name.clone() - }; - trace!("sending {:?}", req); - dialer.send(req) - .map(|d| (d, proto_name, proto_value)) - .from_err() - }) - // Once sent, read one element from `dialer`. - .and_then(|(dialer, proto_name, proto_value)| { - dialer - .into_future() - .map(|(msg, rest)| (msg, rest, proto_name, proto_value)) - .map_err(|(e, _)| e.into()) - }) - // Once read, analyze the response. - .and_then(|(message, rest, proto_name, proto_value)| { - trace!("received {:?}", message); - let message = message.ok_or(ProtocolChoiceError::UnexpectedMessage)?; - - match message { - ListenerToDialerMessage::ProtocolAck { ref name } - if name == &proto_name => - { - // Satisfactory response, break the loop. - Ok(Loop::Break((proto_value, rest.into_inner()))) - }, - ListenerToDialerMessage::NotAvailable => { - Ok(Loop::Continue(rest)) - }, - _ => Err(ProtocolChoiceError::UnexpectedMessage), - } - }) - }) - }) + DialerSelectSeq { + inner: DialerSelectSeqState::AwaitDialer { dialer_fut: Dialer::new(inner), protocols } + } } + +/// Future, returned by `dialer_select_proto_serial` which selects a protocol +/// and dialer sequentially. +pub struct DialerSelectSeq { + inner: DialerSelectSeqState +} + +enum DialerSelectSeqState { + AwaitDialer { + dialer_fut: DialerFuture, + protocols: I + }, + NextProtocol { + dialer: Dialer, + protocols: I + }, + SendProtocol { + sender: sink::Send>, + proto_name: Bytes, + proto_value: P, + protocols: I + }, + AwaitProtocol { + stream: StreamFuture>, + proto_name: Bytes, + proto_value: P, + protocols: I + }, + Undefined +} + +impl Future for DialerSelectSeq +where + I: Iterator, + R: AsyncRead + AsyncWrite, +{ + type Item = (P, R); + type Error = ProtocolChoiceError; + + fn poll(&mut self) -> Poll { + loop { + match mem::replace(&mut self.inner, DialerSelectSeqState::Undefined) { + DialerSelectSeqState::AwaitDialer { mut dialer_fut, protocols } => { + let dialer = match dialer_fut.poll()? { + Async::Ready(d) => d, + Async::NotReady => { + self.inner = DialerSelectSeqState::AwaitDialer { dialer_fut, protocols }; + return Ok(Async::NotReady) + } + }; + self.inner = DialerSelectSeqState::NextProtocol { dialer, protocols } + } + DialerSelectSeqState::NextProtocol { dialer, mut protocols } => { + let (proto_name, proto_value) = + protocols.next().ok_or(ProtocolChoiceError::NoProtocolFound)?; + let req = DialerToListenerMessage::ProtocolRequest { + name: proto_name.clone() + }; + trace!("sending {:?}", req); + let sender = dialer.send(req); + self.inner = DialerSelectSeqState::SendProtocol { + sender, + proto_name, + proto_value, + protocols + } + } + DialerSelectSeqState::SendProtocol { mut sender, proto_name, proto_value, protocols } => { + let dialer = match sender.poll()? { + Async::Ready(d) => d, + Async::NotReady => { + self.inner = DialerSelectSeqState::SendProtocol { + sender, + proto_name, + proto_value, + protocols + }; + return Ok(Async::NotReady) + } + }; + let stream = dialer.into_future(); + self.inner = DialerSelectSeqState::AwaitProtocol { + stream, + proto_name, + proto_value, + protocols + }; + } + DialerSelectSeqState::AwaitProtocol { mut stream, proto_name, proto_value, protocols } => { + let (m, r) = match stream.poll() { + Ok(Async::Ready(x)) => x, + Ok(Async::NotReady) => { + self.inner = DialerSelectSeqState::AwaitProtocol { + stream, + proto_name, + proto_value, + protocols + }; + return Ok(Async::NotReady) + } + Err((e, _)) => return Err(ProtocolChoiceError::from(e)) + }; + trace!("received {:?}", m); + match m.ok_or(ProtocolChoiceError::UnexpectedMessage)? { + ListenerToDialerMessage::ProtocolAck { ref name } if name == &proto_name => { + return Ok(Async::Ready((proto_value, r.into_inner()))) + }, + ListenerToDialerMessage::NotAvailable => { + self.inner = DialerSelectSeqState::NextProtocol { dialer: r, protocols } + } + _ => return Err(ProtocolChoiceError::UnexpectedMessage) + } + } + DialerSelectSeqState::Undefined => + panic!("DialerSelectSeqState::poll called after completion") + } + } + } +} + + /// Helps selecting a protocol amongst the ones supported. /// /// Same as `dialer_select_proto`. Queries the list of supported protocols from the remote, then /// chooses the most appropriate one. -pub fn dialer_select_proto_parallel( - inner: R, - protocols: I, -) -> impl Future +pub fn dialer_select_proto_parallel(inner: R, protocols: I) -> DialerSelectPar where R: AsyncRead + AsyncWrite, I: Iterator, M: FnMut(&Bytes, &Bytes) -> bool, { - Dialer::new(inner) - .from_err() - .and_then(move |dialer| { - trace!("requesting protocols list"); - dialer - .send(DialerToListenerMessage::ProtocolsListRequest) - .from_err() - }) - .and_then(move |dialer| dialer.into_future().map_err(|(e, _)| e.into())) - .and_then(move |(msg, dialer)| { - trace!("protocols list response: {:?}", msg); - let list = match msg { - Some(ListenerToDialerMessage::ProtocolsListResponse { list }) => list, - _ => return Err(ProtocolChoiceError::UnexpectedMessage), - }; + DialerSelectPar { + inner: DialerSelectParState::AwaitDialer { dialer_fut: Dialer::new(inner), protocols } + } +} - let mut found = None; - for (local_name, mut match_fn, ident) in protocols { - for remote_name in &list { - if match_fn(remote_name, &local_name) { - found = Some((remote_name.clone(), ident)); - break; + +/// Future, returned by `dialer_select_proto_parallel`, which selects a protocol and dialer in +/// parellel, by first requesting the liste of protocols supported by the remote endpoint and +/// then selecting the most appropriate one by applying a match predicate to the result. +pub struct DialerSelectPar { + inner: DialerSelectParState +} + +enum DialerSelectParState { + AwaitDialer { + dialer_fut: DialerFuture, + protocols: I + }, + SendRequest { + sender: sink::Send>, + protocols: I + }, + AwaitResponse { + stream: StreamFuture>, + protocols: I + }, + SendProtocol { + sender: sink::Send>, + proto_name: Bytes, + proto_val: P + }, + AwaitProtocol { + stream: StreamFuture>, + proto_name: Bytes, + proto_val: P + }, + Undefined +} + +impl Future for DialerSelectPar +where + I: Iterator, + M: FnMut(&Bytes, &Bytes) -> bool, + R: AsyncRead + AsyncWrite, +{ + type Item = (P, R); + type Error = ProtocolChoiceError; + + fn poll(&mut self) -> Poll { + loop { + match mem::replace(&mut self.inner, DialerSelectParState::Undefined) { + DialerSelectParState::AwaitDialer { mut dialer_fut, protocols } => { + let dialer = match dialer_fut.poll()? { + Async::Ready(d) => d, + Async::NotReady => { + self.inner = DialerSelectParState::AwaitDialer { dialer_fut, protocols }; + return Ok(Async::NotReady) + } + }; + trace!("requesting protocols list"); + let sender = dialer.send(DialerToListenerMessage::ProtocolsListRequest); + self.inner = DialerSelectParState::SendRequest { sender, protocols }; + } + DialerSelectParState::SendRequest { mut sender, protocols } => { + let dialer = match sender.poll()? { + Async::Ready(d) => d, + Async::NotReady => { + self.inner = DialerSelectParState::SendRequest { sender, protocols }; + return Ok(Async::NotReady) + } + }; + let stream = dialer.into_future(); + self.inner = DialerSelectParState::AwaitResponse { stream, protocols }; + } + DialerSelectParState::AwaitResponse { mut stream, protocols } => { + let (m, d) = match stream.poll() { + Ok(Async::Ready(x)) => x, + Ok(Async::NotReady) => { + self.inner = DialerSelectParState::AwaitResponse { stream, protocols }; + return Ok(Async::NotReady) + } + Err((e, _)) => return Err(ProtocolChoiceError::from(e)) + }; + trace!("protocols list response: {:?}", m); + let list = match m { + Some(ListenerToDialerMessage::ProtocolsListResponse { list }) => list, + _ => return Err(ProtocolChoiceError::UnexpectedMessage), + }; + let mut found = None; + for (local_name, mut match_fn, ident) in protocols { + for remote_name in &list { + if match_fn(remote_name, &local_name) { + found = Some((remote_name.clone(), ident)); + break; + } + } + if found.is_some() { + break; + } + } + let (proto_name, proto_val) = found.ok_or(ProtocolChoiceError::NoProtocolFound)?; + trace!("sending {:?}", proto_name); + let sender = d.send(DialerToListenerMessage::ProtocolRequest { + name: proto_name.clone(), + }); + self.inner = DialerSelectParState::SendProtocol { sender, proto_name, proto_val }; + } + DialerSelectParState::SendProtocol { mut sender, proto_name, proto_val } => { + let dialer = match sender.poll()? { + Async::Ready(d) => d, + Async::NotReady => { + self.inner = DialerSelectParState::SendProtocol { + sender, + proto_name, + proto_val + }; + return Ok(Async::NotReady) + } + }; + let stream = dialer.into_future(); + self.inner = DialerSelectParState::AwaitProtocol { + stream, + proto_name, + proto_val + }; + } + DialerSelectParState::AwaitProtocol { mut stream, proto_name, proto_val } => { + let (m, r) = match stream.poll() { + Ok(Async::Ready(x)) => x, + Ok(Async::NotReady) => { + self.inner = DialerSelectParState::AwaitProtocol { + stream, + proto_name, + proto_val + }; + return Ok(Async::NotReady) + } + Err((e, _)) => return Err(ProtocolChoiceError::from(e)) + }; + trace!("received {:?}", m); + match m { + Some(ListenerToDialerMessage::ProtocolAck { ref name }) if name == &proto_name => { + return Ok(Async::Ready((proto_val, r.into_inner()))) + } + _ => return Err(ProtocolChoiceError::UnexpectedMessage) } } - - if found.is_some() { - break; - } + DialerSelectParState::Undefined => + panic!("DialerSelectParState::poll called after completion") } - - let (proto_name, proto_val) = found.ok_or(ProtocolChoiceError::NoProtocolFound)?; - Ok((proto_name, proto_val, dialer)) - }) - .and_then(|(proto_name, proto_val, dialer)| { - trace!("sending {:?}", proto_name); - dialer - .send(DialerToListenerMessage::ProtocolRequest { - name: proto_name.clone(), - }) - .from_err() - .map(|dialer| (proto_name, proto_val, dialer)) - }) - .and_then(|(proto_name, proto_val, dialer)| { - dialer - .into_future() - .map(|(msg, rest)| (proto_name, proto_val, msg, rest)) - .map_err(|(err, _)| err.into()) - }) - .and_then(|(proto_name, proto_val, msg, dialer)| { - trace!("received {:?}", msg); - match msg { - Some(ListenerToDialerMessage::ProtocolAck { ref name }) if name == &proto_name => { - Ok((proto_val, dialer.into_inner())) - } - _ => Err(ProtocolChoiceError::UnexpectedMessage), - } - }) + } + } } + + diff --git a/misc/multistream-select/src/lib.rs b/misc/multistream-select/src/lib.rs index b747d46c..a5eaa43e 100644 --- a/misc/multistream-select/src/lib.rs +++ b/misc/multistream-select/src/lib.rs @@ -113,6 +113,7 @@ //! ``` extern crate bytes; +#[macro_use] extern crate futures; #[macro_use] extern crate log; @@ -128,6 +129,6 @@ mod tests; pub mod protocol; -pub use self::dialer_select::dialer_select_proto; +pub use self::dialer_select::{dialer_select_proto, DialerSelectFuture}; pub use self::error::ProtocolChoiceError; -pub use self::listener_select::listener_select_proto; +pub use self::listener_select::{listener_select_proto, ListenerSelectFuture}; diff --git a/misc/multistream-select/src/listener_select.rs b/misc/multistream-select/src/listener_select.rs index e97674b4..a910ad4a 100644 --- a/misc/multistream-select/src/listener_select.rs +++ b/misc/multistream-select/src/listener_select.rs @@ -22,14 +22,11 @@ //! `multistream-select` for the listener. use bytes::Bytes; -use futures::future::{err, loop_fn, Loop, Either}; -use futures::{Future, Sink, Stream}; -use ProtocolChoiceError; - -use protocol::DialerToListenerMessage; -use protocol::Listener; -use protocol::ListenerToDialerMessage; +use futures::{prelude::*, sink, stream::StreamFuture}; +use protocol::{DialerToListenerMessage, Listener, ListenerFuture, ListenerToDialerMessage}; +use std::mem; use tokio_io::{AsyncRead, AsyncWrite}; +use ProtocolChoiceError; /// Helps selecting a protocol amongst the ones supported. /// @@ -45,61 +42,122 @@ use tokio_io::{AsyncRead, AsyncWrite}; /// /// On success, returns the socket and the identifier of the chosen protocol (of type `P`). The /// socket now uses this protocol. -pub fn listener_select_proto( - inner: R, - protocols: I, -) -> impl Future +pub fn listener_select_proto(inner: R, protocols: I) -> ListenerSelectFuture where R: AsyncRead + AsyncWrite, I: Iterator + Clone, M: FnMut(&Bytes, &Bytes) -> bool, { - Listener::new(inner).from_err().and_then(move |listener| { - loop_fn(listener, move |listener| { - let protocols = protocols.clone(); + ListenerSelectFuture { + inner: ListenerSelectState::AwaitListener { listener_fut: Listener::new(inner), protocols } + } +} - listener - .into_future() - .map_err(|(e, _)| e.into()) - .and_then(move |(message, listener)| match message { - Some(DialerToListenerMessage::ProtocolsListRequest) => { - let msg = ListenerToDialerMessage::ProtocolsListResponse { - list: protocols.map(|(p, _, _)| p).collect(), - }; - trace!("protocols list response: {:?}", msg); - let fut = listener - .send(msg) - .from_err() - .map(move |listener| (None, listener)); - Either::A(Either::A(fut)) - } - Some(DialerToListenerMessage::ProtocolRequest { name }) => { - let mut outcome = None; - let mut send_back = ListenerToDialerMessage::NotAvailable; - for (supported, mut matches, value) in protocols { - if matches(&name, &supported) { - send_back = - ListenerToDialerMessage::ProtocolAck { name: name.clone() }; - outcome = Some(value); - break; +/// Future, returned by `listener_select_proto` which selects a protocol among the ones supported. +pub struct ListenerSelectFuture { + inner: ListenerSelectState +} + +enum ListenerSelectState { + AwaitListener { + listener_fut: ListenerFuture, + protocols: I + }, + Incoming { + stream: StreamFuture>, + protocols: I + }, + Outgoing { + sender: sink::Send>, + protocols: I, + outcome: Option

+ }, + Undefined +} + +impl Future for ListenerSelectFuture +where + I: Iterator + Clone, + M: FnMut(&Bytes, &Bytes) -> bool, + R: AsyncRead + AsyncWrite, +{ + type Item = (P, R); + type Error = ProtocolChoiceError; + + fn poll(&mut self) -> Poll { + loop { + match mem::replace(&mut self.inner, ListenerSelectState::Undefined) { + ListenerSelectState::AwaitListener { mut listener_fut, protocols } => { + let listener = match listener_fut.poll()? { + Async::Ready(l) => l, + Async::NotReady => { + self.inner = ListenerSelectState::AwaitListener { listener_fut, protocols }; + return Ok(Async::NotReady) + } + }; + let stream = listener.into_future(); + self.inner = ListenerSelectState::Incoming { stream, protocols }; + } + ListenerSelectState::Incoming { mut stream, protocols } => { + let (msg, listener) = match stream.poll() { + Ok(Async::Ready(x)) => x, + Ok(Async::NotReady) => { + self.inner = ListenerSelectState::Incoming { stream, protocols }; + return Ok(Async::NotReady) + } + Err((e, _)) => return Err(ProtocolChoiceError::from(e)) + }; + match msg { + Some(DialerToListenerMessage::ProtocolsListRequest) => { + let msg = ListenerToDialerMessage::ProtocolsListResponse { + list: protocols.clone().map(|(p, _, _)| p).collect(), + }; + trace!("protocols list response: {:?}", msg); + let sender = listener.send(msg); + self.inner = ListenerSelectState::Outgoing { + sender, + protocols, + outcome: None } } - trace!("requested: {:?}, response: {:?}", name, send_back); - let fut = listener - .send(send_back) - .from_err() - .map(move |listener| (outcome, listener)); - Either::A(Either::B(fut)) + Some(DialerToListenerMessage::ProtocolRequest { name }) => { + let mut outcome = None; + let mut send_back = ListenerToDialerMessage::NotAvailable; + for (supported, mut matches, value) in protocols.clone() { + if matches(&name, &supported) { + send_back = ListenerToDialerMessage::ProtocolAck {name: name.clone()}; + outcome = Some(value); + break; + } + } + trace!("requested: {:?}, response: {:?}", name, send_back); + let sender = listener.send(send_back); + self.inner = ListenerSelectState::Outgoing { sender, protocols, outcome } + } + None => { + debug!("no protocol request received"); + return Err(ProtocolChoiceError::NoProtocolFound) + } } - None => { - debug!("no protocol request received"); - Either::B(err(ProtocolChoiceError::NoProtocolFound)) + } + ListenerSelectState::Outgoing { mut sender, protocols, outcome } => { + let listener = match sender.poll()? { + Async::Ready(l) => l, + Async::NotReady => { + self.inner = ListenerSelectState::Outgoing { sender, protocols, outcome }; + return Ok(Async::NotReady) + } + }; + if let Some(p) = outcome { + return Ok(Async::Ready((p, listener.into_inner()))) + } else { + let stream = listener.into_future(); + self.inner = ListenerSelectState::Incoming { stream, protocols } } - }) - .map(|(outcome, listener): (_, Listener)| match outcome { - Some(outcome) => Loop::Break((outcome, listener.into_inner())), - None => Loop::Continue(listener), - }) - }) - }) + } + ListenerSelectState::Undefined => + panic!("ListenerSelectState::poll called after completion") + } + } + } } diff --git a/misc/multistream-select/src/protocol/dialer.rs b/misc/multistream-select/src/protocol/dialer.rs index 8809c0b3..896095ee 100644 --- a/misc/multistream-select/src/protocol/dialer.rs +++ b/misc/multistream-select/src/protocol/dialer.rs @@ -21,7 +21,7 @@ //! Contains the `Dialer` wrapper, which allows raw communications with a listener. use bytes::{Bytes, BytesMut}; -use futures::{Async, AsyncSink, Future, Poll, Sink, StartSend, Stream}; +use futures::{prelude::*, sink, Async, AsyncSink, StartSend}; use length_delimited::LengthDelimitedFramedRead; use protocol::DialerToListenerMessage; use protocol::ListenerToDialerMessage; @@ -32,6 +32,7 @@ use tokio_io::codec::length_delimited::FramedWrite as LengthDelimitedFramedWrite use tokio_io::{AsyncRead, AsyncWrite}; use unsigned_varint::decode; + /// Wraps around a `AsyncRead+AsyncWrite`. Assumes that we're on the dialer's side. Produces and /// accepts messages. pub struct Dialer { @@ -45,19 +46,12 @@ where { /// Takes ownership of a socket and starts the handshake. If the handshake succeeds, the /// future returns a `Dialer`. - pub fn new(inner: R) -> impl Future, Error = MultistreamSelectError> { - let write = LengthDelimitedBuilder::new() - .length_field_length(1) - .new_write(inner); - let inner = LengthDelimitedFramedRead::new(write); - - inner - .send(BytesMut::from(MULTISTREAM_PROTOCOL_WITH_LF)) - .from_err() - .map(|inner| Dialer { - inner, - handshake_finished: false, - }) + pub fn new(inner: R) -> DialerFuture { + let write = LengthDelimitedBuilder::new().length_field_length(1).new_write(inner); + let sender = LengthDelimitedFramedRead::new(write); + DialerFuture { + inner: sender.send(BytesMut::from(MULTISTREAM_PROTOCOL_WITH_LF)) + } } /// Grants back the socket. Typically used after a `ProtocolAck` has been received. @@ -170,6 +164,22 @@ where } } +/// Future, returned by `Dialer::new`, which send the handshake and returns the actual `Dialer`. +pub struct DialerFuture { + inner: sink::Send>> +} + +impl Future for DialerFuture { + type Item = Dialer; + type Error = MultistreamSelectError; + + fn poll(&mut self) -> Poll { + let inner = try_ready!(self.inner.poll()); + Ok(Async::Ready(Dialer { inner, handshake_finished: false })) + } +} + + #[cfg(test)] mod tests { extern crate tokio_current_thread; diff --git a/misc/multistream-select/src/protocol/listener.rs b/misc/multistream-select/src/protocol/listener.rs index f1a39ae5..ba4a8516 100644 --- a/misc/multistream-select/src/protocol/listener.rs +++ b/misc/multistream-select/src/protocol/listener.rs @@ -21,17 +21,19 @@ //! Contains the `Listener` wrapper, which allows raw communications with a dialer. use bytes::{Bytes, BytesMut}; -use futures::{Async, AsyncSink, Future, Poll, Sink, StartSend, Stream}; +use futures::{Async, AsyncSink, prelude::*, sink, stream::StreamFuture}; use length_delimited::LengthDelimitedFramedRead; use protocol::DialerToListenerMessage; use protocol::ListenerToDialerMessage; use protocol::MultistreamSelectError; use protocol::MULTISTREAM_PROTOCOL_WITH_LF; +use std::mem; use tokio_io::codec::length_delimited::Builder as LengthDelimitedBuilder; use tokio_io::codec::length_delimited::FramedWrite as LengthDelimitedFramedWrite; use tokio_io::{AsyncRead, AsyncWrite}; use unsigned_varint::encode; + /// Wraps around a `AsyncRead+AsyncWrite`. Assumes that we're on the listener's side. Produces and /// accepts messages. pub struct Listener { @@ -44,29 +46,14 @@ where { /// Takes ownership of a socket and starts the handshake. If the handshake succeeds, the /// future returns a `Listener`. - pub fn new(inner: R) -> impl Future, Error = MultistreamSelectError> { + pub fn new(inner: R) -> ListenerFuture { let write = LengthDelimitedBuilder::new() .length_field_length(1) .new_write(inner); let inner = LengthDelimitedFramedRead::::new(write); - - inner - .into_future() - .map_err(|(e, _)| e.into()) - .and_then(|(msg, rest)| { - if msg.as_ref().map(|b| &b[..]) != Some(MULTISTREAM_PROTOCOL_WITH_LF) { - debug!("failed handshake; received: {:?}", msg); - return Err(MultistreamSelectError::FailedHandshake); - } - Ok(rest) - }) - .and_then(|socket| { - trace!("sending back /multistream/ to finish the handshake"); - socket - .send(BytesMut::from(MULTISTREAM_PROTOCOL_WITH_LF)) - .from_err() - }) - .map(|inner| Listener { inner }) + ListenerFuture { + inner: ListenerFutureState::Await { inner: inner.into_future() } + } } /// Grants back the socket. Typically used after a `ProtocolRequest` has been received and a @@ -178,6 +165,65 @@ where } } + +/// Future, returned by `Listener::new` which performs the handshake and returns +/// the `Listener` if successful. +pub struct ListenerFuture { + inner: ListenerFutureState +} + +enum ListenerFutureState { + Await { + inner: StreamFuture>> + }, + Reply { + sender: sink::Send>> + }, + Undefined +} + +impl Future for ListenerFuture { + type Item = Listener; + type Error = MultistreamSelectError; + + fn poll(&mut self) -> Poll { + loop { + match mem::replace(&mut self.inner, ListenerFutureState::Undefined) { + ListenerFutureState::Await { mut inner } => { + let (msg, socket) = + match inner.poll() { + Ok(Async::Ready(x)) => x, + Ok(Async::NotReady) => { + self.inner = ListenerFutureState::Await { inner }; + return Ok(Async::NotReady) + } + Err((e, _)) => return Err(MultistreamSelectError::from(e)) + }; + if msg.as_ref().map(|b| &b[..]) != Some(MULTISTREAM_PROTOCOL_WITH_LF) { + debug!("failed handshake; received: {:?}", msg); + return Err(MultistreamSelectError::FailedHandshake) + } + trace!("sending back /multistream/ to finish the handshake"); + let sender = socket.send(BytesMut::from(MULTISTREAM_PROTOCOL_WITH_LF)); + self.inner = ListenerFutureState::Reply { sender } + } + ListenerFutureState::Reply { mut sender } => { + let listener = match sender.poll()? { + Async::Ready(x) => x, + Async::NotReady => { + self.inner = ListenerFutureState::Reply { sender }; + return Ok(Async::NotReady) + } + }; + return Ok(Async::Ready(Listener { inner: listener })) + } + ListenerFutureState::Undefined => panic!("ListenerFutureState::poll called after completion") + } + } + } +} + + #[cfg(test)] mod tests { extern crate tokio_current_thread; diff --git a/misc/multistream-select/src/protocol/mod.rs b/misc/multistream-select/src/protocol/mod.rs index c5d97f31..23594128 100644 --- a/misc/multistream-select/src/protocol/mod.rs +++ b/misc/multistream-select/src/protocol/mod.rs @@ -28,9 +28,9 @@ mod listener; const MULTISTREAM_PROTOCOL_WITH_LF: &[u8] = b"/multistream/1.0.0\n"; -pub use self::dialer::Dialer; +pub use self::dialer::{Dialer, DialerFuture}; pub use self::error::MultistreamSelectError; -pub use self::listener::Listener; +pub use self::listener::{Listener, ListenerFuture}; /// Message sent from the dialer to the listener. #[derive(Debug, Clone, PartialEq, Eq)]