Update multistream-select to stable futures (#1484)

* Update multistream-select to stable futures

* Fix intradoc links
This commit is contained in:
Pierre Krieger
2020-03-11 14:49:41 +01:00
committed by GitHub
parent 2084fadd86
commit 31271fc824
11 changed files with 851 additions and 652 deletions

View File

@ -15,7 +15,7 @@ bs58 = "0.3.0"
ed25519-dalek = "1.0.0-pre.3" ed25519-dalek = "1.0.0-pre.3"
either = "1.5" either = "1.5"
fnv = "1.0" fnv = "1.0"
futures = { version = "0.3.1", features = ["compat", "io-compat", "executor", "thread-pool"] } futures = { version = "0.3.1", features = ["executor", "thread-pool"] }
futures-timer = "3" futures-timer = "3"
lazy_static = "1.2" lazy_static = "1.2"
libsecp256k1 = { version = "0.3.1", optional = true } libsecp256k1 = { version = "0.3.1", optional = true }

View File

@ -41,7 +41,7 @@ mod keys_proto {
/// Multi-address re-export. /// Multi-address re-export.
pub use multiaddr; pub use multiaddr;
pub type Negotiated<T> = futures::compat::Compat01As03<multistream_select::Negotiated<futures::compat::Compat<T>>>; pub type Negotiated<T> = multistream_select::Negotiated<T>;
mod peer_id; mod peer_id;
mod translation; mod translation;

View File

@ -20,7 +20,7 @@
use crate::{ConnectedPoint, Negotiated}; use crate::{ConnectedPoint, Negotiated};
use crate::upgrade::{InboundUpgrade, OutboundUpgrade, UpgradeError, ProtocolName}; use crate::upgrade::{InboundUpgrade, OutboundUpgrade, UpgradeError, ProtocolName};
use futures::{future::Either, prelude::*, compat::Compat, compat::Compat01As03, compat::Future01CompatExt}; use futures::{future::Either, prelude::*};
use log::debug; use log::debug;
use multistream_select::{self, DialerSelectFuture, ListenerSelectFuture}; use multistream_select::{self, DialerSelectFuture, ListenerSelectFuture};
use std::{iter, mem, pin::Pin, task::Context, task::Poll}; use std::{iter, mem, pin::Pin, task::Context, task::Poll};
@ -48,7 +48,7 @@ where
U: InboundUpgrade<Negotiated<C>>, U: InboundUpgrade<Negotiated<C>>,
{ {
let iter = up.protocol_info().into_iter().map(NameWrap as fn(_) -> NameWrap<_>); let iter = up.protocol_info().into_iter().map(NameWrap as fn(_) -> NameWrap<_>);
let future = multistream_select::listener_select_proto(Compat::new(conn), iter).compat(); let future = multistream_select::listener_select_proto(conn, iter);
InboundUpgradeApply { InboundUpgradeApply {
inner: InboundUpgradeApplyState::Init { future, upgrade: up } inner: InboundUpgradeApplyState::Init { future, upgrade: up }
} }
@ -61,7 +61,7 @@ where
U: OutboundUpgrade<Negotiated<C>> U: OutboundUpgrade<Negotiated<C>>
{ {
let iter = up.protocol_info().into_iter().map(NameWrap as fn(_) -> NameWrap<_>); let iter = up.protocol_info().into_iter().map(NameWrap as fn(_) -> NameWrap<_>);
let future = multistream_select::dialer_select_proto(Compat::new(conn), iter, v).compat(); let future = multistream_select::dialer_select_proto(conn, iter, v);
OutboundUpgradeApply { OutboundUpgradeApply {
inner: OutboundUpgradeApplyState::Init { future, upgrade: up } inner: OutboundUpgradeApplyState::Init { future, upgrade: up }
} }
@ -82,7 +82,7 @@ where
U: InboundUpgrade<Negotiated<C>>, U: InboundUpgrade<Negotiated<C>>,
{ {
Init { Init {
future: Compat01As03<ListenerSelectFuture<Compat<C>, NameWrap<U::Info>>>, future: ListenerSelectFuture<C, NameWrap<U::Info>>,
upgrade: U, upgrade: U,
}, },
Upgrade { Upgrade {
@ -117,7 +117,7 @@ where
} }
}; };
self.inner = InboundUpgradeApplyState::Upgrade { self.inner = InboundUpgradeApplyState::Upgrade {
future: Box::pin(upgrade.upgrade_inbound(Compat01As03::new(io), info.0)) future: Box::pin(upgrade.upgrade_inbound(io, info.0))
}; };
} }
InboundUpgradeApplyState::Upgrade { mut future } => { InboundUpgradeApplyState::Upgrade { mut future } => {
@ -158,7 +158,7 @@ where
U: OutboundUpgrade<Negotiated<C>> U: OutboundUpgrade<Negotiated<C>>
{ {
Init { Init {
future: Compat01As03<DialerSelectFuture<Compat<C>, NameWrapIter<<U::InfoIter as IntoIterator>::IntoIter>>>, future: DialerSelectFuture<C, NameWrapIter<<U::InfoIter as IntoIterator>::IntoIter>>,
upgrade: U upgrade: U
}, },
Upgrade { Upgrade {
@ -193,7 +193,7 @@ where
} }
}; };
self.inner = OutboundUpgradeApplyState::Upgrade { self.inner = OutboundUpgradeApplyState::Upgrade {
future: Box::pin(upgrade.upgrade_outbound(Compat01As03::new(connection), info.0)) future: Box::pin(upgrade.upgrade_outbound(connection, info.0))
}; };
} }
OutboundUpgradeApplyState::Upgrade { mut future } => { OutboundUpgradeApplyState::Upgrade { mut future } => {

View File

@ -11,14 +11,14 @@ edition = "2018"
[dependencies] [dependencies]
bytes = "0.5" bytes = "0.5"
futures = "0.1" futures = "0.3"
log = "0.4" log = "0.4"
pin-project = "0.4.8"
smallvec = "1.0" smallvec = "1.0"
tokio-io = "0.1" unsigned-varint = "0.3.2"
unsigned-varint = "0.3"
[dev-dependencies] [dev-dependencies]
tokio = "0.1" async-std = "1.2"
tokio-tcp = "0.1"
quickcheck = "0.9.0" quickcheck = "0.9.0"
rand = "0.7.2" rand = "0.7.2"
rw-stream-sink = "0.2.1"

View File

@ -20,12 +20,11 @@
//! Protocol negotiation strategies for the peer acting as the dialer. //! Protocol negotiation strategies for the peer acting as the dialer.
use crate::protocol::{Protocol, ProtocolError, MessageIO, Message, Version};
use futures::{future::Either, prelude::*};
use log::debug;
use std::{io, iter, mem, convert::TryFrom};
use tokio_io::{AsyncRead, AsyncWrite};
use crate::{Negotiated, NegotiationError}; use crate::{Negotiated, NegotiationError};
use crate::protocol::{Protocol, ProtocolError, MessageIO, Message, Version};
use futures::{future::Either, prelude::*};
use std::{convert::TryFrom as _, io, iter, mem, pin::Pin, task::{Context, Poll}};
/// Returns a `Future` that negotiates a protocol on the given I/O stream /// Returns a `Future` that negotiates a protocol on the given I/O stream
/// for a peer acting as the _dialer_ (or _initiator_). /// for a peer acting as the _dialer_ (or _initiator_).
@ -60,9 +59,9 @@ where
let iter = protocols.into_iter(); let iter = protocols.into_iter();
// We choose between the "serial" and "parallel" strategies based on the number of protocols. // We choose between the "serial" and "parallel" strategies based on the number of protocols.
if iter.size_hint().1.map(|n| n <= 3).unwrap_or(false) { if iter.size_hint().1.map(|n| n <= 3).unwrap_or(false) {
Either::A(dialer_select_proto_serial(inner, iter, version)) Either::Left(dialer_select_proto_serial(inner, iter, version))
} else { } else {
Either::B(dialer_select_proto_parallel(inner, iter, version)) Either::Right(dialer_select_proto_parallel(inner, iter, version))
} }
} }
@ -129,6 +128,7 @@ where
/// A `Future` returned by [`dialer_select_proto_serial`] which negotiates /// A `Future` returned by [`dialer_select_proto_serial`] which negotiates
/// a protocol iteratively by considering one protocol after the other. /// a protocol iteratively by considering one protocol after the other.
#[pin_project::pin_project]
pub struct DialerSelectSeq<R, I> pub struct DialerSelectSeq<R, I>
where where
R: AsyncRead + AsyncWrite, R: AsyncRead + AsyncWrite,
@ -155,83 +155,107 @@ where
impl<R, I> Future for DialerSelectSeq<R, I> impl<R, I> Future for DialerSelectSeq<R, I>
where where
R: AsyncRead + AsyncWrite, // The Unpin bound here is required because we produce a `Negotiated<R>` as the output.
// It also makes the implementation considerably easier to write.
R: AsyncRead + AsyncWrite + Unpin,
I: Iterator, I: Iterator,
I::Item: AsRef<[u8]> I::Item: AsRef<[u8]>
{ {
type Item = (I::Item, Negotiated<R>); type Output = Result<(I::Item, Negotiated<R>), NegotiationError>;
type Error = NegotiationError;
fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
let this = self.project();
fn poll(&mut self) -> Poll<Self::Item, Self::Error> {
loop { loop {
match mem::replace(&mut self.state, SeqState::Done) { match mem::replace(this.state, SeqState::Done) {
SeqState::SendHeader { mut io } => { SeqState::SendHeader { mut io } => {
if io.start_send(Message::Header(self.version))?.is_not_ready() { match Pin::new(&mut io).poll_ready(cx)? {
self.state = SeqState::SendHeader { io }; Poll::Ready(()) => {},
return Ok(Async::NotReady) Poll::Pending => {
*this.state = SeqState::SendHeader { io };
return Poll::Pending
},
} }
let protocol = self.protocols.next().ok_or(NegotiationError::Failed)?;
self.state = SeqState::SendProtocol { io, protocol }; if let Err(err) = Pin::new(&mut io).start_send(Message::Header(*this.version)) {
return Poll::Ready(Err(From::from(err)));
}
let protocol = this.protocols.next().ok_or(NegotiationError::Failed)?;
*this.state = SeqState::SendProtocol { io, protocol };
} }
SeqState::SendProtocol { mut io, protocol } => { SeqState::SendProtocol { mut io, protocol } => {
let p = Protocol::try_from(protocol.as_ref())?; match Pin::new(&mut io).poll_ready(cx)? {
if io.start_send(Message::Protocol(p.clone()))?.is_not_ready() { Poll::Ready(()) => {},
self.state = SeqState::SendProtocol { io, protocol }; Poll::Pending => {
return Ok(Async::NotReady) *this.state = SeqState::SendProtocol { io, protocol };
return Poll::Pending
},
} }
debug!("Dialer: Proposed protocol: {}", p);
if self.protocols.peek().is_some() { let p = Protocol::try_from(protocol.as_ref())?;
self.state = SeqState::FlushProtocol { io, protocol } if let Err(err) = Pin::new(&mut io).start_send(Message::Protocol(p.clone())) {
return Poll::Ready(Err(From::from(err)));
}
log::debug!("Dialer: Proposed protocol: {}", p);
if this.protocols.peek().is_some() {
*this.state = SeqState::FlushProtocol { io, protocol }
} else { } else {
match self.version { match this.version {
Version::V1 => self.state = SeqState::FlushProtocol { io, protocol }, Version::V1 => *this.state = SeqState::FlushProtocol { io, protocol },
Version::V1Lazy => { Version::V1Lazy => {
debug!("Dialer: Expecting proposed protocol: {}", p); log::debug!("Dialer: Expecting proposed protocol: {}", p);
let io = Negotiated::expecting(io.into_reader(), p, self.version); let io = Negotiated::expecting(io.into_reader(), p, *this.version);
return Ok(Async::Ready((protocol, io))) return Poll::Ready(Ok((protocol, io)))
} }
} }
} }
} }
SeqState::FlushProtocol { mut io, protocol } => { SeqState::FlushProtocol { mut io, protocol } => {
if io.poll_complete()?.is_not_ready() { match Pin::new(&mut io).poll_flush(cx)? {
self.state = SeqState::FlushProtocol { io, protocol }; Poll::Ready(()) => *this.state = SeqState::AwaitProtocol { io, protocol },
return Ok(Async::NotReady) Poll::Pending => {
*this.state = SeqState::FlushProtocol { io, protocol };
return Poll::Pending
},
} }
self.state = SeqState::AwaitProtocol { io, protocol }
} }
SeqState::AwaitProtocol { mut io, protocol } => { SeqState::AwaitProtocol { mut io, protocol } => {
let msg = match io.poll()? { let msg = match Pin::new(&mut io).poll_next(cx)? {
Async::NotReady => { Poll::Ready(Some(msg)) => msg,
self.state = SeqState::AwaitProtocol { io, protocol }; Poll::Pending => {
return Ok(Async::NotReady) *this.state = SeqState::AwaitProtocol { io, protocol };
return Poll::Pending
} }
Async::Ready(None) => Poll::Ready(None) =>
return Err(NegotiationError::from( return Poll::Ready(Err(NegotiationError::from(
io::Error::from(io::ErrorKind::UnexpectedEof))), io::Error::from(io::ErrorKind::UnexpectedEof)))),
Async::Ready(Some(msg)) => msg,
}; };
match msg { match msg {
Message::Header(v) if v == self.version => { Message::Header(v) if v == *this.version => {
self.state = SeqState::AwaitProtocol { io, protocol }; *this.state = SeqState::AwaitProtocol { io, protocol };
} }
Message::Protocol(ref p) if p.as_ref() == protocol.as_ref() => { Message::Protocol(ref p) if p.as_ref() == protocol.as_ref() => {
debug!("Dialer: Received confirmation for protocol: {}", p); log::debug!("Dialer: Received confirmation for protocol: {}", p);
let (io, remaining) = io.into_inner(); let (io, remaining) = io.into_inner();
let io = Negotiated::completed(io, remaining); let io = Negotiated::completed(io, remaining);
return Ok(Async::Ready((protocol, io))) return Poll::Ready(Ok((protocol, io)));
} }
Message::NotAvailable => { Message::NotAvailable => {
debug!("Dialer: Received rejection of protocol: {}", log::debug!("Dialer: Received rejection of protocol: {}",
String::from_utf8_lossy(protocol.as_ref())); String::from_utf8_lossy(protocol.as_ref()));
let protocol = self.protocols.next() let protocol = this.protocols.next().ok_or(NegotiationError::Failed)?;
.ok_or(NegotiationError::Failed)?; *this.state = SeqState::SendProtocol { io, protocol }
self.state = SeqState::SendProtocol { io, protocol }
} }
_ => return Err(ProtocolError::InvalidMessage.into()) _ => return Poll::Ready(Err(ProtocolError::InvalidMessage.into())),
} }
} }
SeqState::Done => panic!("SeqState::poll called after completion") SeqState::Done => panic!("SeqState::poll called after completion")
} }
} }
@ -241,6 +265,7 @@ where
/// A `Future` returned by [`dialer_select_proto_parallel`] which negotiates /// A `Future` returned by [`dialer_select_proto_parallel`] which negotiates
/// a protocol selectively by considering all supported protocols of the remote /// a protocol selectively by considering all supported protocols of the remote
/// "in parallel". /// "in parallel".
#[pin_project::pin_project]
pub struct DialerSelectPar<R, I> pub struct DialerSelectPar<R, I>
where where
R: AsyncRead + AsyncWrite, R: AsyncRead + AsyncWrite,
@ -267,76 +292,110 @@ where
impl<R, I> Future for DialerSelectPar<R, I> impl<R, I> Future for DialerSelectPar<R, I>
where where
R: AsyncRead + AsyncWrite, // The Unpin bound here is required because we produce a `Negotiated<R>` as the output.
// It also makes the implementation considerably easier to write.
R: AsyncRead + AsyncWrite + Unpin,
I: Iterator, I: Iterator,
I::Item: AsRef<[u8]> I::Item: AsRef<[u8]>
{ {
type Item = (I::Item, Negotiated<R>); type Output = Result<(I::Item, Negotiated<R>), NegotiationError>;
type Error = NegotiationError;
fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
let this = self.project();
fn poll(&mut self) -> Poll<Self::Item, Self::Error> {
loop { loop {
match mem::replace(&mut self.state, ParState::Done) { match mem::replace(this.state, ParState::Done) {
ParState::SendHeader { mut io } => { ParState::SendHeader { mut io } => {
if io.start_send(Message::Header(self.version))?.is_not_ready() { match Pin::new(&mut io).poll_ready(cx)? {
self.state = ParState::SendHeader { io }; Poll::Ready(()) => {},
return Ok(Async::NotReady) Poll::Pending => {
*this.state = ParState::SendHeader { io };
return Poll::Pending
},
} }
self.state = ParState::SendProtocolsRequest { io };
if let Err(err) = Pin::new(&mut io).start_send(Message::Header(*this.version)) {
return Poll::Ready(Err(From::from(err)));
}
*this.state = ParState::SendProtocolsRequest { io };
} }
ParState::SendProtocolsRequest { mut io } => { ParState::SendProtocolsRequest { mut io } => {
if io.start_send(Message::ListProtocols)?.is_not_ready() { match Pin::new(&mut io).poll_ready(cx)? {
self.state = ParState::SendProtocolsRequest { io }; Poll::Ready(()) => {},
return Ok(Async::NotReady) Poll::Pending => {
*this.state = ParState::SendProtocolsRequest { io };
return Poll::Pending
},
} }
debug!("Dialer: Requested supported protocols.");
self.state = ParState::Flush { io } if let Err(err) = Pin::new(&mut io).start_send(Message::ListProtocols) {
return Poll::Ready(Err(From::from(err)));
}
log::debug!("Dialer: Requested supported protocols.");
*this.state = ParState::Flush { io }
} }
ParState::Flush { mut io } => { ParState::Flush { mut io } => {
if io.poll_complete()?.is_not_ready() { match Pin::new(&mut io).poll_flush(cx)? {
self.state = ParState::Flush { io }; Poll::Ready(()) => *this.state = ParState::RecvProtocols { io },
return Ok(Async::NotReady) Poll::Pending => {
*this.state = ParState::Flush { io };
return Poll::Pending
},
} }
self.state = ParState::RecvProtocols { io }
} }
ParState::RecvProtocols { mut io } => { ParState::RecvProtocols { mut io } => {
let msg = match io.poll()? { let msg = match Pin::new(&mut io).poll_next(cx)? {
Async::NotReady => { Poll::Ready(Some(msg)) => msg,
self.state = ParState::RecvProtocols { io }; Poll::Pending => {
return Ok(Async::NotReady) *this.state = ParState::RecvProtocols { io };
return Poll::Pending
} }
Async::Ready(None) => Poll::Ready(None) =>
return Err(NegotiationError::from( return Poll::Ready(Err(NegotiationError::from(
io::Error::from(io::ErrorKind::UnexpectedEof))), io::Error::from(io::ErrorKind::UnexpectedEof)))),
Async::Ready(Some(msg)) => msg,
}; };
match &msg { match &msg {
Message::Header(v) if v == &self.version => { Message::Header(v) if v == this.version => {
self.state = ParState::RecvProtocols { io } *this.state = ParState::RecvProtocols { io }
} }
Message::Protocols(supported) => { Message::Protocols(supported) => {
let protocol = self.protocols.by_ref() let protocol = this.protocols.by_ref()
.find(|p| supported.iter().any(|s| .find(|p| supported.iter().any(|s|
s.as_ref() == p.as_ref())) s.as_ref() == p.as_ref()))
.ok_or(NegotiationError::Failed)?; .ok_or(NegotiationError::Failed)?;
debug!("Dialer: Found supported protocol: {}", log::debug!("Dialer: Found supported protocol: {}",
String::from_utf8_lossy(protocol.as_ref())); String::from_utf8_lossy(protocol.as_ref()));
self.state = ParState::SendProtocol { io, protocol }; *this.state = ParState::SendProtocol { io, protocol };
} }
_ => return Err(ProtocolError::InvalidMessage.into()) _ => return Poll::Ready(Err(ProtocolError::InvalidMessage.into())),
} }
} }
ParState::SendProtocol { mut io, protocol } => { ParState::SendProtocol { mut io, protocol } => {
let p = Protocol::try_from(protocol.as_ref())?; match Pin::new(&mut io).poll_ready(cx)? {
if io.start_send(Message::Protocol(p.clone()))?.is_not_ready() { Poll::Ready(()) => {},
self.state = ParState::SendProtocol { io, protocol }; Poll::Pending => {
return Ok(Async::NotReady) *this.state = ParState::SendProtocol { io, protocol };
return Poll::Pending
},
} }
debug!("Dialer: Expecting proposed protocol: {}", p);
let io = Negotiated::expecting(io.into_reader(), p, self.version); let p = Protocol::try_from(protocol.as_ref())?;
return Ok(Async::Ready((protocol, io))) if let Err(err) = Pin::new(&mut io).start_send(Message::Protocol(p.clone())) {
return Poll::Ready(Err(From::from(err)));
}
log::debug!("Dialer: Expecting proposed protocol: {}", p);
let io = Negotiated::expecting(io.into_reader(), p, *this.version);
return Poll::Ready(Ok((protocol, io)))
} }
ParState::Done => panic!("ParState::poll called after completion") ParState::Done => panic!("ParState::poll called after completion")
} }
} }

View File

@ -18,11 +18,9 @@
// FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER // FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
// DEALINGS IN THE SOFTWARE. // DEALINGS IN THE SOFTWARE.
use bytes::{Bytes, BytesMut, Buf, BufMut}; use bytes::{Bytes, BytesMut, Buf as _, BufMut as _};
use futures::{try_ready, Async, Poll, Sink, StartSend, Stream, AsyncSink}; use futures::{prelude::*, io::IoSlice};
use std::{io, u16}; use std::{convert::TryFrom as _, io, pin::Pin, task::{Poll, Context}, u16};
use tokio_io::{AsyncRead, AsyncWrite};
use unsigned_varint as uvi;
const MAX_LEN_BYTES: u16 = 2; const MAX_LEN_BYTES: u16 = 2;
const MAX_FRAME_SIZE: u16 = (1 << (MAX_LEN_BYTES * 8 - MAX_LEN_BYTES)) - 1; const MAX_FRAME_SIZE: u16 = (1 << (MAX_LEN_BYTES * 8 - MAX_LEN_BYTES)) - 1;
@ -34,9 +32,11 @@ const DEFAULT_BUFFER_SIZE: usize = 64;
/// We purposely only support a frame sizes up to 16KiB (2 bytes unsigned varint /// We purposely only support a frame sizes up to 16KiB (2 bytes unsigned varint
/// frame length). Frames mostly consist in a short protocol name, which is highly /// frame length). Frames mostly consist in a short protocol name, which is highly
/// unlikely to be more than 16KiB long. /// unlikely to be more than 16KiB long.
#[pin_project::pin_project]
#[derive(Debug)] #[derive(Debug)]
pub struct LengthDelimited<R> { pub struct LengthDelimited<R> {
/// The inner I/O resource. /// The inner I/O resource.
#[pin]
inner: R, inner: R,
/// Read buffer for a single incoming unsigned-varint length-delimited frame. /// Read buffer for a single incoming unsigned-varint length-delimited frame.
read_buffer: BytesMut, read_buffer: BytesMut,
@ -76,20 +76,7 @@ impl<R> LengthDelimited<R> {
} }
} }
/// Returns a reference to the underlying I/O stream. /// Drops the [`LengthDelimited`] resource, yielding the underlying I/O stream
pub fn inner_ref(&self) -> &R {
&self.inner
}
/// Returns a mutable reference to the underlying I/O stream.
///
/// > **Note**: Care should be taken to not tamper with the underlying stream of data
/// > coming in, as it may corrupt the stream of frames.
pub fn inner_mut(&mut self) -> &mut R {
&mut self.inner
}
/// Drops the `LengthDelimited` resource, yielding the underlying I/O stream
/// together with the remaining write buffer containing the uvi-framed data /// together with the remaining write buffer containing the uvi-framed data
/// that has not yet been written to the underlying I/O stream. /// that has not yet been written to the underlying I/O stream.
/// ///
@ -107,7 +94,7 @@ impl<R> LengthDelimited<R> {
(self.inner, self.write_buffer) (self.inner, self.write_buffer)
} }
/// Converts the `LengthDelimited` into a `LengthDelimitedReader`, dropping the /// Converts the [`LengthDelimited`] into a [`LengthDelimitedReader`], dropping the
/// uvi-framed `Sink` in favour of direct `AsyncWrite` access to the underlying /// uvi-framed `Sink` in favour of direct `AsyncWrite` access to the underlying
/// I/O stream. /// I/O stream.
/// ///
@ -121,25 +108,29 @@ impl<R> LengthDelimited<R> {
/// Writes all buffered frame data to the underlying I/O stream, /// Writes all buffered frame data to the underlying I/O stream,
/// _without flushing it_. /// _without flushing it_.
/// ///
/// After this method returns `Async::Ready`, the write buffer of frames /// After this method returns `Poll::Ready`, the write buffer of frames
/// submitted to the `Sink` is guaranteed to be empty. /// submitted to the `Sink` is guaranteed to be empty.
pub fn poll_write_buffer(&mut self) -> Poll<(), io::Error> pub fn poll_write_buffer(self: Pin<&mut Self>, cx: &mut Context)
-> Poll<Result<(), io::Error>>
where where
R: AsyncWrite R: AsyncWrite
{ {
while !self.write_buffer.is_empty() { let mut this = self.project();
let n = try_ready!(self.inner.poll_write(&self.write_buffer));
if n == 0 { while !this.write_buffer.is_empty() {
return Err(io::Error::new( match this.inner.as_mut().poll_write(cx, &this.write_buffer) {
io::ErrorKind::WriteZero, Poll::Pending => return Poll::Pending,
"Failed to write buffered frame.")) Poll::Ready(Ok(0)) => {
return Poll::Ready(Err(io::Error::new(
io::ErrorKind::WriteZero,
"Failed to write buffered frame.")))
}
Poll::Ready(Ok(n)) => this.write_buffer.advance(n),
Poll::Ready(Err(err)) => return Poll::Ready(Err(err)),
} }
self.write_buffer.advance(n);
} }
Ok(Async::Ready(())) Poll::Ready(Ok(()))
} }
} }
@ -147,72 +138,67 @@ impl<R> Stream for LengthDelimited<R>
where where
R: AsyncRead R: AsyncRead
{ {
type Item = Bytes; type Item = Result<Bytes, io::Error>;
type Error = io::Error;
fn poll_next(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Option<Self::Item>> {
let mut this = self.project();
fn poll(&mut self) -> Poll<Option<Self::Item>, Self::Error> {
loop { loop {
match &mut self.read_state { match this.read_state {
ReadState::ReadLength { buf, pos } => { ReadState::ReadLength { buf, pos } => {
match self.inner.read(&mut buf[*pos .. *pos + 1]) { match this.inner.as_mut().poll_read(cx, &mut buf[*pos .. *pos + 1]) {
Ok(0) => { Poll::Ready(Ok(0)) => {
if *pos == 0 { if *pos == 0 {
return Ok(Async::Ready(None)); return Poll::Ready(None);
} else { } else {
return Err(io::ErrorKind::UnexpectedEof.into()); return Poll::Ready(Some(Err(io::ErrorKind::UnexpectedEof.into())));
} }
} }
Ok(n) => { Poll::Ready(Ok(n)) => {
debug_assert_eq!(n, 1); debug_assert_eq!(n, 1);
*pos += n; *pos += n;
} }
Err(ref err) if err.kind() == io::ErrorKind::WouldBlock => { Poll::Ready(Err(err)) => return Poll::Ready(Some(Err(err))),
return Ok(Async::NotReady); Poll::Pending => return Poll::Pending,
}
Err(err) => {
return Err(err);
}
}; };
if (buf[*pos - 1] & 0x80) == 0 { if (buf[*pos - 1] & 0x80) == 0 {
// MSB is not set, indicating the end of the length prefix. // MSB is not set, indicating the end of the length prefix.
let (len, _) = uvi::decode::u16(buf).map_err(|e| { let (len, _) = unsigned_varint::decode::u16(buf)
log::debug!("invalid length prefix: {}", e); .map_err(|e| {
io::Error::new(io::ErrorKind::InvalidData, "invalid length prefix") log::debug!("invalid length prefix: {}", e);
})?; io::Error::new(io::ErrorKind::InvalidData, "invalid length prefix")
})?;
if len >= 1 { if len >= 1 {
self.read_state = ReadState::ReadData { len, pos: 0 }; *this.read_state = ReadState::ReadData { len, pos: 0 };
self.read_buffer.resize(len as usize, 0); this.read_buffer.resize(len as usize, 0);
} else { } else {
debug_assert_eq!(len, 0); debug_assert_eq!(len, 0);
self.read_state = ReadState::default(); *this.read_state = ReadState::default();
return Ok(Async::Ready(Some(Bytes::new()))); return Poll::Ready(Some(Ok(Bytes::new())));
} }
} else if *pos == MAX_LEN_BYTES as usize { } else if *pos == MAX_LEN_BYTES as usize {
// MSB signals more length bytes but we have already read the maximum. // MSB signals more length bytes but we have already read the maximum.
// See the module documentation about the max frame len. // See the module documentation about the max frame len.
return Err(io::Error::new( return Poll::Ready(Some(Err(io::Error::new(
io::ErrorKind::InvalidData, io::ErrorKind::InvalidData,
"Maximum frame length exceeded")); "Maximum frame length exceeded"))));
} }
} }
ReadState::ReadData { len, pos } => { ReadState::ReadData { len, pos } => {
match self.inner.read(&mut self.read_buffer[*pos..]) { match this.inner.as_mut().poll_read(cx, &mut this.read_buffer[*pos..]) {
Ok(0) => return Err(io::ErrorKind::UnexpectedEof.into()), Poll::Ready(Ok(0)) => return Poll::Ready(Some(Err(io::ErrorKind::UnexpectedEof.into()))),
Ok(n) => *pos += n, Poll::Ready(Ok(n)) => *pos += n,
Err(err) => Poll::Pending => return Poll::Pending,
if err.kind() == io::ErrorKind::WouldBlock { Poll::Ready(Err(err)) => return Poll::Ready(Some(Err(err))),
return Ok(Async::NotReady)
} else {
return Err(err)
}
}; };
if *pos == *len as usize { if *pos == *len as usize {
// Finished reading the frame. // Finished reading the frame.
let frame = self.read_buffer.split_off(0).freeze(); let frame = this.read_buffer.split_off(0).freeze();
self.read_state = ReadState::default(); *this.read_state = ReadState::default();
return Ok(Async::Ready(Some(frame))); return Poll::Ready(Some(Ok(frame)));
} }
} }
} }
@ -220,58 +206,87 @@ where
} }
} }
impl<R> Sink for LengthDelimited<R> impl<R> Sink<Bytes> for LengthDelimited<R>
where where
R: AsyncWrite, R: AsyncWrite,
{ {
type SinkItem = Bytes; type Error = io::Error;
type SinkError = io::Error;
fn start_send(&mut self, msg: Self::SinkItem) -> StartSend<Self::SinkItem, Self::SinkError> { fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
// Use the maximum frame length also as a (soft) upper limit // Use the maximum frame length also as a (soft) upper limit
// for the entire write buffer. The actual (hard) limit is thus // for the entire write buffer. The actual (hard) limit is thus
// implied to be roughly 2 * MAX_FRAME_SIZE. // implied to be roughly 2 * MAX_FRAME_SIZE.
if self.write_buffer.len() >= MAX_FRAME_SIZE as usize { if self.as_mut().project().write_buffer.len() >= MAX_FRAME_SIZE as usize {
self.poll_complete()?; match self.as_mut().poll_write_buffer(cx) {
if self.write_buffer.len() >= MAX_FRAME_SIZE as usize { Poll::Ready(Ok(())) => {},
return Ok(AsyncSink::NotReady(msg)) Poll::Ready(Err(err)) => return Poll::Ready(Err(err)),
Poll::Pending => return Poll::Pending,
} }
debug_assert!(self.as_mut().project().write_buffer.is_empty());
} }
let len = msg.len() as u16; Poll::Ready(Ok(()))
if len > MAX_FRAME_SIZE {
return Err(io::Error::new(
io::ErrorKind::InvalidData,
"Maximum frame size exceeded."))
}
let mut uvi_buf = uvi::encode::u16_buffer();
let uvi_len = uvi::encode::u16(len, &mut uvi_buf);
self.write_buffer.reserve(len as usize + uvi_len.len());
self.write_buffer.put(uvi_len);
self.write_buffer.put(msg);
Ok(AsyncSink::Ready)
} }
fn poll_complete(&mut self) -> Poll<(), Self::SinkError> { fn start_send(self: Pin<&mut Self>, item: Bytes) -> Result<(), Self::Error> {
let this = self.project();
let len = match u16::try_from(item.len()) {
Ok(len) if len <= MAX_FRAME_SIZE => len,
_ => {
return Err(io::Error::new(
io::ErrorKind::InvalidData,
"Maximum frame size exceeded."))
}
};
let mut uvi_buf = unsigned_varint::encode::u16_buffer();
let uvi_len = unsigned_varint::encode::u16(len, &mut uvi_buf);
this.write_buffer.reserve(len as usize + uvi_len.len());
this.write_buffer.put(uvi_len);
this.write_buffer.put(item);
Ok(())
}
fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
// Write all buffered frame data to the underlying I/O stream. // Write all buffered frame data to the underlying I/O stream.
try_ready!(self.poll_write_buffer()); match LengthDelimited::poll_write_buffer(self.as_mut(), cx) {
Poll::Ready(Ok(())) => {},
Poll::Ready(Err(err)) => return Poll::Ready(Err(err)),
Poll::Pending => return Poll::Pending,
}
let this = self.project();
debug_assert!(this.write_buffer.is_empty());
// Flush the underlying I/O stream. // Flush the underlying I/O stream.
try_ready!(self.inner.poll_flush()); this.inner.poll_flush(cx)
return Ok(Async::Ready(()));
} }
fn close(&mut self) -> Poll<(), Self::SinkError> { fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
try_ready!(self.poll_complete()); // Write all buffered frame data to the underlying I/O stream.
Ok(self.inner.shutdown()?) match LengthDelimited::poll_write_buffer(self.as_mut(), cx) {
Poll::Ready(Ok(())) => {},
Poll::Ready(Err(err)) => return Poll::Ready(Err(err)),
Poll::Pending => return Poll::Pending,
}
let this = self.project();
debug_assert!(this.write_buffer.is_empty());
// Close the underlying I/O stream.
this.inner.poll_close(cx)
} }
} }
/// A `LengthDelimitedReader` implements a `Stream` of uvi-length-delimited /// A `LengthDelimitedReader` implements a `Stream` of uvi-length-delimited
/// frames on an underlying I/O resource combined with direct `AsyncWrite` access. /// frames on an underlying I/O resource combined with direct `AsyncWrite` access.
#[pin_project::pin_project]
#[derive(Debug)] #[derive(Debug)]
pub struct LengthDelimitedReader<R> { pub struct LengthDelimitedReader<R> {
#[pin]
inner: LengthDelimited<R> inner: LengthDelimited<R>
} }
@ -284,58 +299,23 @@ impl<R> LengthDelimitedReader<R> {
/// # Panic /// # Panic
/// ///
/// Will panic if called while there is data in the read or write buffer. /// Will panic if called while there is data in the read or write buffer.
/// The read buffer is guaranteed to be empty whenever [`Stream::poll`] yields /// The read buffer is guaranteed to be empty whenever [`Stream::poll_next`]
/// a new `Message`. The write buffer is guaranteed to be empty whenever /// yield a new `Message`. The write buffer is guaranteed to be empty whenever
/// [`LengthDelimited::poll_write_buffer`] yields [`Async::Ready`] or after /// [`LengthDelimited::poll_write_buffer`] yields [`Poll::Ready`] or after
/// the [`Sink`] has been completely flushed via [`Sink::poll_complete`]. /// the [`Sink`] has been completely flushed via [`Sink::poll_flush`].
pub fn into_inner(self) -> (R, BytesMut) { pub fn into_inner(self) -> (R, BytesMut) {
self.inner.into_inner() self.inner.into_inner()
} }
/// Returns a reference to the underlying I/O stream.
pub fn inner_ref(&self) -> &R {
self.inner.inner_ref()
}
/// Returns a mutable reference to the underlying I/O stream.
///
/// > **Note**: Care should be taken to not tamper with the underlying stream of data
/// > coming in, as it may corrupt the stream of frames.
pub fn inner_mut(&mut self) -> &mut R {
self.inner.inner_mut()
}
} }
impl<R> Stream for LengthDelimitedReader<R> impl<R> Stream for LengthDelimitedReader<R>
where where
R: AsyncRead R: AsyncRead
{ {
type Item = Bytes; type Item = Result<Bytes, io::Error>;
type Error = io::Error;
fn poll(&mut self) -> Poll<Option<Self::Item>, Self::Error> { fn poll_next(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Option<Self::Item>> {
self.inner.poll() self.project().inner.poll_next(cx)
}
}
impl<R> io::Write for LengthDelimitedReader<R>
where
R: AsyncWrite
{
fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
while !self.inner.write_buffer.is_empty() {
if self.inner.poll_write_buffer()?.is_not_ready() {
return Err(io::ErrorKind::WouldBlock.into())
}
}
self.inner_mut().write(buf)
}
fn flush(&mut self) -> io::Result<()> {
match self.inner.poll_complete()? {
Async::Ready(()) => Ok(()),
Async::NotReady => Err(io::ErrorKind::WouldBlock.into())
}
} }
} }
@ -343,23 +323,62 @@ impl<R> AsyncWrite for LengthDelimitedReader<R>
where where
R: AsyncWrite R: AsyncWrite
{ {
fn shutdown(&mut self) -> Poll<(), io::Error> { fn poll_write(self: Pin<&mut Self>, cx: &mut Context, buf: &[u8])
try_ready!(self.inner.poll_complete()); -> Poll<Result<usize, io::Error>>
self.inner_mut().shutdown() {
// `this` here designates the `LengthDelimited`.
let mut this = self.project().inner;
// We need to flush any data previously written with the `LengthDelimited`.
match LengthDelimited::poll_write_buffer(this.as_mut(), cx) {
Poll::Ready(Ok(())) => {},
Poll::Ready(Err(err)) => return Poll::Ready(Err(err)),
Poll::Pending => return Poll::Pending,
}
debug_assert!(this.write_buffer.is_empty());
this.project().inner.poll_write(cx, buf)
}
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), io::Error>> {
self.project().inner.poll_flush(cx)
}
fn poll_close(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), io::Error>> {
self.project().inner.poll_close(cx)
}
fn poll_write_vectored(self: Pin<&mut Self>, cx: &mut Context, bufs: &[IoSlice])
-> Poll<Result<usize, io::Error>>
{
// `this` here designates the `LengthDelimited`.
let mut this = self.project().inner;
// We need to flush any data previously written with the `LengthDelimited`.
match LengthDelimited::poll_write_buffer(this.as_mut(), cx) {
Poll::Ready(Ok(())) => {},
Poll::Ready(Err(err)) => return Poll::Ready(Err(err)),
Poll::Pending => return Poll::Pending,
}
debug_assert!(this.write_buffer.is_empty());
this.project().inner.poll_write_vectored(cx, bufs)
} }
} }
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use futures::{Future, Stream};
use crate::length_delimited::LengthDelimited; use crate::length_delimited::LengthDelimited;
use std::io::{Cursor, ErrorKind}; use async_std::net::{TcpListener, TcpStream};
use futures::{prelude::*, io::Cursor};
use quickcheck::*;
use std::io::ErrorKind;
#[test] #[test]
fn basic_read() { fn basic_read() {
let data = vec![6, 9, 8, 7, 6, 5, 4]; let data = vec![6, 9, 8, 7, 6, 5, 4];
let framed = LengthDelimited::new(Cursor::new(data)); let framed = LengthDelimited::new(Cursor::new(data));
let recved = framed.collect().wait().unwrap(); let recved = futures::executor::block_on(framed.try_collect::<Vec<_>>()).unwrap();
assert_eq!(recved, vec![vec![9, 8, 7, 6, 5, 4]]); assert_eq!(recved, vec![vec![9, 8, 7, 6, 5, 4]]);
} }
@ -367,7 +386,7 @@ mod tests {
fn basic_read_two() { fn basic_read_two() {
let data = vec![6, 9, 8, 7, 6, 5, 4, 3, 9, 8, 7]; let data = vec![6, 9, 8, 7, 6, 5, 4, 3, 9, 8, 7];
let framed = LengthDelimited::new(Cursor::new(data)); let framed = LengthDelimited::new(Cursor::new(data));
let recved = framed.collect().wait().unwrap(); let recved = futures::executor::block_on(framed.try_collect::<Vec<_>>()).unwrap();
assert_eq!(recved, vec![vec![9, 8, 7, 6, 5, 4], vec![9, 8, 7]]); assert_eq!(recved, vec![vec![9, 8, 7, 6, 5, 4], vec![9, 8, 7]]);
} }
@ -378,13 +397,10 @@ mod tests {
let frame = (0..len).map(|n| (n & 0xff) as u8).collect::<Vec<_>>(); let frame = (0..len).map(|n| (n & 0xff) as u8).collect::<Vec<_>>();
let mut data = vec![(len & 0x7f) as u8 | 0x80, (len >> 7) as u8]; let mut data = vec![(len & 0x7f) as u8 | 0x80, (len >> 7) as u8];
data.extend(frame.clone().into_iter()); data.extend(frame.clone().into_iter());
let framed = LengthDelimited::new(Cursor::new(data)); let mut framed = LengthDelimited::new(Cursor::new(data));
let recved = framed let recved = futures::executor::block_on(async move {
.into_future() framed.next().await
.map(|(m, _)| m) }).unwrap();
.map_err(|_| ())
.wait()
.unwrap();
assert_eq!(recved.unwrap(), frame); assert_eq!(recved.unwrap(), frame);
} }
@ -392,12 +408,10 @@ mod tests {
fn packet_len_too_long() { fn packet_len_too_long() {
let mut data = vec![0x81, 0x81, 0x1]; let mut data = vec![0x81, 0x81, 0x1];
data.extend((0..16513).map(|_| 0)); data.extend((0..16513).map(|_| 0));
let framed = LengthDelimited::new(Cursor::new(data)); let mut framed = LengthDelimited::new(Cursor::new(data));
let recved = framed let recved = futures::executor::block_on(async move {
.into_future() framed.next().await.unwrap()
.map(|(m, _)| m) });
.map_err(|(err, _)| err)
.wait();
if let Err(io_err) = recved { if let Err(io_err) = recved {
assert_eq!(io_err.kind(), ErrorKind::InvalidData) assert_eq!(io_err.kind(), ErrorKind::InvalidData)
@ -410,7 +424,7 @@ mod tests {
fn empty_frames() { fn empty_frames() {
let data = vec![0, 0, 6, 9, 8, 7, 6, 5, 4, 0, 3, 9, 8, 7]; let data = vec![0, 0, 6, 9, 8, 7, 6, 5, 4, 0, 3, 9, 8, 7];
let framed = LengthDelimited::new(Cursor::new(data)); let framed = LengthDelimited::new(Cursor::new(data));
let recved = framed.collect().wait().unwrap(); let recved = futures::executor::block_on(framed.try_collect::<Vec<_>>()).unwrap();
assert_eq!( assert_eq!(
recved, recved,
vec![ vec![
@ -427,7 +441,7 @@ mod tests {
fn unexpected_eof_in_len() { fn unexpected_eof_in_len() {
let data = vec![0x89]; let data = vec![0x89];
let framed = LengthDelimited::new(Cursor::new(data)); let framed = LengthDelimited::new(Cursor::new(data));
let recved = framed.collect().wait(); let recved = futures::executor::block_on(framed.try_collect::<Vec<_>>());
if let Err(io_err) = recved { if let Err(io_err) = recved {
assert_eq!(io_err.kind(), ErrorKind::UnexpectedEof) assert_eq!(io_err.kind(), ErrorKind::UnexpectedEof)
} else { } else {
@ -439,7 +453,7 @@ mod tests {
fn unexpected_eof_in_data() { fn unexpected_eof_in_data() {
let data = vec![5]; let data = vec![5];
let framed = LengthDelimited::new(Cursor::new(data)); let framed = LengthDelimited::new(Cursor::new(data));
let recved = framed.collect().wait(); let recved = futures::executor::block_on(framed.try_collect::<Vec<_>>());
if let Err(io_err) = recved { if let Err(io_err) = recved {
assert_eq!(io_err.kind(), ErrorKind::UnexpectedEof) assert_eq!(io_err.kind(), ErrorKind::UnexpectedEof)
} else { } else {
@ -451,12 +465,54 @@ mod tests {
fn unexpected_eof_in_data2() { fn unexpected_eof_in_data2() {
let data = vec![5, 9, 8, 7]; let data = vec![5, 9, 8, 7];
let framed = LengthDelimited::new(Cursor::new(data)); let framed = LengthDelimited::new(Cursor::new(data));
let recved = framed.collect().wait(); let recved = futures::executor::block_on(framed.try_collect::<Vec<_>>());
if let Err(io_err) = recved { if let Err(io_err) = recved {
assert_eq!(io_err.kind(), ErrorKind::UnexpectedEof) assert_eq!(io_err.kind(), ErrorKind::UnexpectedEof)
} else { } else {
panic!() panic!()
} }
} }
}
#[test]
fn writing_reading() {
fn prop(frames: Vec<Vec<u8>>) -> TestResult {
async_std::task::block_on(async move {
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
let listener_addr = listener.local_addr().unwrap();
let expected_frames = frames.clone();
let server = async_std::task::spawn(async move {
let socket = listener.accept().await.unwrap().0;
let mut connec = rw_stream_sink::RwStreamSink::new(LengthDelimited::new(socket));
let mut buf = vec![0u8; 0];
for expected in expected_frames {
if expected.is_empty() {
continue;
}
if buf.len() < expected.len() {
buf.resize(expected.len(), 0);
}
let n = connec.read(&mut buf).await.unwrap();
assert_eq!(&buf[..n], &expected[..]);
}
});
let client = async_std::task::spawn(async move {
let socket = TcpStream::connect(&listener_addr).await.unwrap();
let mut connec = LengthDelimited::new(socket);
for frame in frames {
connec.send(From::from(frame)).await.unwrap();
}
});
server.await;
client.await;
});
TestResult::passed()
}
quickcheck(prop as fn(_) -> _)
}
}

View File

@ -77,26 +77,19 @@
//! //!
//! ```no_run //! ```no_run
//! # fn main() { //! # fn main() {
//! use bytes::Bytes; //! use async_std::net::TcpStream;
//! use multistream_select::{dialer_select_proto, Version}; //! use multistream_select::{dialer_select_proto, Version};
//! use futures::{Future, Sink, Stream}; //! use futures::prelude::*;
//! use tokio_tcp::TcpStream;
//! use tokio::runtime::current_thread::Runtime;
//! //!
//! #[derive(Debug, Copy, Clone)] //! async_std::task::block_on(async move {
//! enum MyProto { Echo, Hello } //! let socket = TcpStream::connect("127.0.0.1:10333").await.unwrap();
//! //!
//! let client = TcpStream::connect(&"127.0.0.1:10333".parse().unwrap()) //! let protos = vec![b"/echo/1.0.0", b"/echo/2.5.0"];
//! .from_err() //! let (protocol, _io) = dialer_select_proto(socket, protos, Version::V1).await.unwrap();
//! .and_then(move |io| {
//! let protos = vec![b"/echo/1.0.0", b"/echo/2.5.0"];
//! dialer_select_proto(io, protos, Version::V1)
//! })
//! .map(|(protocol, _io)| protocol);
//! //!
//! let mut rt = Runtime::new().unwrap(); //! println!("Negotiated protocol: {:?}", protocol);
//! let protocol = rt.block_on(client).expect("failed to find a protocol"); //! // You can now use `_io` to communicate with the remote.
//! println!("Negotiated protocol: {:?}", protocol); //! });
//! # } //! # }
//! ``` //! ```
//! //!

View File

@ -21,13 +21,12 @@
//! Protocol negotiation strategies for the peer acting as the listener //! Protocol negotiation strategies for the peer acting as the listener
//! in a multistream-select protocol negotiation. //! in a multistream-select protocol negotiation.
use futures::prelude::*;
use crate::protocol::{Protocol, ProtocolError, MessageIO, Message, Version};
use log::{debug, warn};
use smallvec::SmallVec;
use std::{io, iter::FromIterator, mem, convert::TryFrom};
use tokio_io::{AsyncRead, AsyncWrite};
use crate::{Negotiated, NegotiationError}; use crate::{Negotiated, NegotiationError};
use crate::protocol::{Protocol, ProtocolError, MessageIO, Message, Version};
use futures::prelude::*;
use smallvec::SmallVec;
use std::{convert::TryFrom as _, io, iter::FromIterator, mem, pin::Pin, task::{Context, Poll}};
/// Returns a `Future` that negotiates a protocol on the given I/O stream /// Returns a `Future` that negotiates a protocol on the given I/O stream
/// for a peer acting as the _listener_ (or _responder_). /// for a peer acting as the _listener_ (or _responder_).
@ -49,7 +48,7 @@ where
match Protocol::try_from(n.as_ref()) { match Protocol::try_from(n.as_ref()) {
Ok(p) => Some((n, p)), Ok(p) => Some((n, p)),
Err(e) => { Err(e) => {
warn!("Listener: Ignoring invalid protocol: {} due to {}", log::warn!("Listener: Ignoring invalid protocol: {} due to {}",
String::from_utf8_lossy(n.as_ref()), e); String::from_utf8_lossy(n.as_ref()), e);
None None
} }
@ -64,6 +63,7 @@ where
/// The `Future` returned by [`listener_select_proto`] that performs a /// The `Future` returned by [`listener_select_proto`] that performs a
/// multistream-select protocol negotiation on an underlying I/O stream. /// multistream-select protocol negotiation on an underlying I/O stream.
#[pin_project::pin_project]
pub struct ListenerSelectFuture<R, N> pub struct ListenerSelectFuture<R, N>
where where
R: AsyncRead + AsyncWrite, R: AsyncRead + AsyncWrite,
@ -94,64 +94,80 @@ where
impl<R, N> Future for ListenerSelectFuture<R, N> impl<R, N> Future for ListenerSelectFuture<R, N>
where where
R: AsyncRead + AsyncWrite, // The Unpin bound here is required because we produce a `Negotiated<R>` as the output.
// It also makes the implementation considerably easier to write.
R: AsyncRead + AsyncWrite + Unpin,
N: AsRef<[u8]> + Clone N: AsRef<[u8]> + Clone
{ {
type Item = (N, Negotiated<R>); type Output = Result<(N, Negotiated<R>), NegotiationError>;
type Error = NegotiationError;
fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
let this = self.project();
fn poll(&mut self) -> Poll<Self::Item, Self::Error> {
loop { loop {
match mem::replace(&mut self.state, State::Done) { match mem::replace(this.state, State::Done) {
State::RecvHeader { mut io } => { State::RecvHeader { mut io } => {
match io.poll()? { match io.poll_next_unpin(cx) {
Async::Ready(Some(Message::Header(version))) => { Poll::Ready(Some(Ok(Message::Header(version)))) => {
self.state = State::SendHeader { io, version } *this.state = State::SendHeader { io, version }
} }
Async::Ready(Some(_)) => { Poll::Ready(Some(Ok(_))) => {
return Err(ProtocolError::InvalidMessage.into()) return Poll::Ready(Err(ProtocolError::InvalidMessage.into()))
} },
Async::Ready(None) => Poll::Ready(Some(Err(err))) => return Poll::Ready(Err(From::from(err))),
return Err(NegotiationError::from( Poll::Ready(None) =>
return Poll::Ready(Err(NegotiationError::from(
ProtocolError::IoError( ProtocolError::IoError(
io::ErrorKind::UnexpectedEof.into()))), io::ErrorKind::UnexpectedEof.into())))),
Async::NotReady => { Poll::Pending => {
self.state = State::RecvHeader { io }; *this.state = State::RecvHeader { io };
return Ok(Async::NotReady) return Poll::Pending
} }
} }
} }
State::SendHeader { mut io, version } => { State::SendHeader { mut io, version } => {
if io.start_send(Message::Header(version))?.is_not_ready() { match Pin::new(&mut io).poll_ready(cx) {
return Ok(Async::NotReady) Poll::Pending => {
*this.state = State::SendHeader { io, version };
return Poll::Pending
},
Poll::Ready(Ok(())) => {},
Poll::Ready(Err(err)) => return Poll::Ready(Err(From::from(err))),
} }
self.state = match version {
if let Err(err) = Pin::new(&mut io).start_send(Message::Header(version)) {
return Poll::Ready(Err(From::from(err)));
}
*this.state = match version {
Version::V1 => State::Flush { io }, Version::V1 => State::Flush { io },
Version::V1Lazy => State::RecvMessage { io }, Version::V1Lazy => State::RecvMessage { io },
} }
} }
State::RecvMessage { mut io } => { State::RecvMessage { mut io } => {
let msg = match io.poll() { let msg = match Pin::new(&mut io).poll_next(cx) {
Ok(Async::Ready(Some(msg))) => msg, Poll::Ready(Some(Ok(msg))) => msg,
Ok(Async::Ready(None)) => Poll::Ready(None) =>
return Err(NegotiationError::from( return Poll::Ready(Err(NegotiationError::from(
ProtocolError::IoError( ProtocolError::IoError(
io::ErrorKind::UnexpectedEof.into()))), io::ErrorKind::UnexpectedEof.into())))),
Ok(Async::NotReady) => { Poll::Pending => {
self.state = State::RecvMessage { io }; *this.state = State::RecvMessage { io };
return Ok(Async::NotReady) return Poll::Pending;
} }
Err(e) => return Err(e.into()) Poll::Ready(Some(Err(err))) => return Poll::Ready(Err(From::from(err))),
}; };
match msg { match msg {
Message::ListProtocols => { Message::ListProtocols => {
let supported = self.protocols.iter().map(|(_,p)| p).cloned().collect(); let supported = this.protocols.iter().map(|(_,p)| p).cloned().collect();
let message = Message::Protocols(supported); let message = Message::Protocols(supported);
self.state = State::SendMessage { io, message, protocol: None } *this.state = State::SendMessage { io, message, protocol: None }
} }
Message::Protocol(p) => { Message::Protocol(p) => {
let protocol = self.protocols.iter().find_map(|(name, proto)| { let protocol = this.protocols.iter().find_map(|(name, proto)| {
if &p == proto { if &p == proto {
Some(name.clone()) Some(name.clone())
} else { } else {
@ -160,45 +176,60 @@ where
}); });
let message = if protocol.is_some() { let message = if protocol.is_some() {
debug!("Listener: confirming protocol: {}", p); log::debug!("Listener: confirming protocol: {}", p);
Message::Protocol(p.clone()) Message::Protocol(p.clone())
} else { } else {
debug!("Listener: rejecting protocol: {}", log::debug!("Listener: rejecting protocol: {}",
String::from_utf8_lossy(p.as_ref())); String::from_utf8_lossy(p.as_ref()));
Message::NotAvailable Message::NotAvailable
}; };
self.state = State::SendMessage { io, message, protocol }; *this.state = State::SendMessage { io, message, protocol };
} }
_ => return Err(ProtocolError::InvalidMessage.into()) _ => return Poll::Ready(Err(ProtocolError::InvalidMessage.into()))
} }
} }
State::SendMessage { mut io, message, protocol } => { State::SendMessage { mut io, message, protocol } => {
if let AsyncSink::NotReady(message) = io.start_send(message)? { match Pin::new(&mut io).poll_ready(cx) {
self.state = State::SendMessage { io, message, protocol }; Poll::Pending => {
return Ok(Async::NotReady) *this.state = State::SendMessage { io, message, protocol };
}; return Poll::Pending
},
Poll::Ready(Ok(())) => {},
Poll::Ready(Err(err)) => return Poll::Ready(Err(From::from(err))),
}
if let Err(err) = Pin::new(&mut io).start_send(message) {
return Poll::Ready(Err(From::from(err)));
}
// If a protocol has been selected, finish negotiation. // If a protocol has been selected, finish negotiation.
// Otherwise flush the sink and expect to receive another // Otherwise flush the sink and expect to receive another
// message. // message.
self.state = match protocol { *this.state = match protocol {
Some(protocol) => { Some(protocol) => {
debug!("Listener: sent confirmed protocol: {}", log::debug!("Listener: sent confirmed protocol: {}",
String::from_utf8_lossy(protocol.as_ref())); String::from_utf8_lossy(protocol.as_ref()));
let (io, remaining) = io.into_inner(); let (io, remaining) = io.into_inner();
let io = Negotiated::completed(io, remaining); let io = Negotiated::completed(io, remaining);
return Ok(Async::Ready((protocol, io))) return Poll::Ready(Ok((protocol, io)));
} }
None => State::Flush { io } None => State::Flush { io }
}; };
} }
State::Flush { mut io } => { State::Flush { mut io } => {
if io.poll_complete()?.is_not_ready() { match Pin::new(&mut io).poll_flush(cx) {
self.state = State::Flush { io }; Poll::Pending => {
return Ok(Async::NotReady) *this.state = State::Flush { io };
return Poll::Pending
},
Poll::Ready(Ok(())) => *this.state = State::RecvMessage { io },
Poll::Ready(Err(err)) => return Poll::Ready(Err(From::from(err))),
} }
self.state = State::RecvMessage { io }
} }
State::Done => panic!("State::poll called after completion") State::Done => panic!("State::poll called after completion")
} }
} }

View File

@ -18,12 +18,12 @@
// FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER // FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
// DEALINGS IN THE SOFTWARE. // DEALINGS IN THE SOFTWARE.
use bytes::{BytesMut, Buf};
use crate::protocol::{Protocol, MessageReader, Message, Version, ProtocolError}; use crate::protocol::{Protocol, MessageReader, Message, Version, ProtocolError};
use futures::{prelude::*, Async, try_ready};
use log::debug; use bytes::{BytesMut, Buf};
use tokio_io::{AsyncRead, AsyncWrite}; use futures::{prelude::*, io::{IoSlice, IoSliceMut}, ready};
use std::{mem, io, fmt, error::Error}; use pin_project::{pin_project, project};
use std::{error::Error, fmt, io, mem, pin::Pin, task::{Context, Poll}};
/// An I/O stream that has settled on an (application-layer) protocol to use. /// An I/O stream that has settled on an (application-layer) protocol to use.
/// ///
@ -36,28 +36,40 @@ use std::{mem, io, fmt, error::Error};
/// ///
/// Reading from a `Negotiated` I/O stream that still has pending negotiation /// Reading from a `Negotiated` I/O stream that still has pending negotiation
/// protocol data to send implicitly triggers flushing of all yet unsent data. /// protocol data to send implicitly triggers flushing of all yet unsent data.
#[pin_project]
#[derive(Debug)] #[derive(Debug)]
pub struct Negotiated<TInner> { pub struct Negotiated<TInner> {
#[pin]
state: State<TInner> state: State<TInner>
} }
/// A `Future` that waits on the completion of protocol negotiation. /// A `Future` that waits on the completion of protocol negotiation.
#[derive(Debug)] #[derive(Debug)]
pub struct NegotiatedComplete<TInner> { pub struct NegotiatedComplete<TInner> {
inner: Option<Negotiated<TInner>> inner: Option<Negotiated<TInner>>,
} }
impl<TInner: AsyncRead + AsyncWrite> Future for NegotiatedComplete<TInner> { impl<TInner> Future for NegotiatedComplete<TInner>
type Item = Negotiated<TInner>; where
type Error = NegotiationError; // `Unpin` is required not because of implementation details but because we produce the
// `Negotiated` as the output of the future.
TInner: AsyncRead + AsyncWrite + Unpin,
{
type Output = Result<Negotiated<TInner>, NegotiationError>;
fn poll(&mut self) -> Poll<Self::Item, Self::Error> { fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
let mut io = self.inner.take().expect("NegotiatedFuture called after completion."); let mut io = self.inner.take().expect("NegotiatedFuture called after completion.");
if io.poll()?.is_not_ready() { match Negotiated::poll(Pin::new(&mut io), cx) {
self.inner = Some(io); Poll::Pending => {
return Ok(Async::NotReady) self.inner = Some(io);
return Poll::Pending
},
Poll::Ready(Ok(())) => Poll::Ready(Ok(io)),
Poll::Ready(Err(err)) => {
self.inner = Some(io);
return Poll::Ready(Err(err));
}
} }
return Ok(Async::Ready(io))
} }
} }
@ -75,66 +87,67 @@ impl<TInner> Negotiated<TInner> {
} }
/// Polls the `Negotiated` for completion. /// Polls the `Negotiated` for completion.
fn poll(&mut self) -> Poll<(), NegotiationError> #[project]
fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), NegotiationError>>
where where
TInner: AsyncRead + AsyncWrite TInner: AsyncRead + AsyncWrite + Unpin
{ {
// Flush any pending negotiation data. // Flush any pending negotiation data.
match self.poll_flush() { match self.as_mut().poll_flush(cx) {
Ok(Async::Ready(())) => {}, Poll::Ready(Ok(())) => {},
Ok(Async::NotReady) => return Ok(Async::NotReady), Poll::Pending => return Poll::Pending,
Err(e) => { Poll::Ready(Err(e)) => {
// If the remote closed the stream, it is important to still // If the remote closed the stream, it is important to still
// continue reading the data that was sent, if any. // continue reading the data that was sent, if any.
if e.kind() != io::ErrorKind::WriteZero { if e.kind() != io::ErrorKind::WriteZero {
return Err(e.into()) return Poll::Ready(Err(e.into()))
} }
} }
} }
if let State::Completed { remaining, .. } = &mut self.state { let mut this = self.project();
let _ = remaining.split_to(remaining.len()); // Drop remaining data flushed above.
return Ok(Async::Ready(())) #[project]
match this.state.as_mut().project() {
State::Completed { remaining, .. } => {
debug_assert!(remaining.is_empty());
return Poll::Ready(Ok(()))
}
_ => {}
} }
// Read outstanding protocol negotiation messages. // Read outstanding protocol negotiation messages.
loop { loop {
match mem::replace(&mut self.state, State::Invalid) { match mem::replace(&mut *this.state, State::Invalid) {
State::Expecting { mut io, protocol, version } => { State::Expecting { mut io, protocol, version } => {
let msg = match io.poll() { let msg = match Pin::new(&mut io).poll_next(cx)? {
Ok(Async::Ready(Some(msg))) => msg, Poll::Ready(Some(msg)) => msg,
Ok(Async::NotReady) => { Poll::Pending => {
self.state = State::Expecting { io, protocol, version }; *this.state = State::Expecting { io, protocol, version };
return Ok(Async::NotReady) return Poll::Pending
} },
Ok(Async::Ready(None)) => { Poll::Ready(None) => {
self.state = State::Expecting { io, protocol, version }; return Poll::Ready(Err(ProtocolError::IoError(
return Err(ProtocolError::IoError( io::ErrorKind::UnexpectedEof.into()).into()));
io::ErrorKind::UnexpectedEof.into()).into())
}
Err(err) => {
self.state = State::Expecting { io, protocol, version };
return Err(err.into())
} }
}; };
if let Message::Header(v) = &msg { if let Message::Header(v) = &msg {
if v == &version { if *v == version {
self.state = State::Expecting { io, protocol, version };
continue continue
} }
} }
if let Message::Protocol(p) = &msg { if let Message::Protocol(p) = &msg {
if p.as_ref() == protocol.as_ref() { if p.as_ref() == protocol.as_ref() {
debug!("Negotiated: Received confirmation for protocol: {}", p); log::debug!("Negotiated: Received confirmation for protocol: {}", p);
let (io, remaining) = io.into_inner(); let (io, remaining) = io.into_inner();
self.state = State::Completed { io, remaining }; *this.state = State::Completed { io, remaining };
return Ok(Async::Ready(())) return Poll::Ready(Ok(()));
} }
} }
return Err(NegotiationError::Failed) return Poll::Ready(Err(NegotiationError::Failed));
} }
_ => panic!("Negotiated: Invalid state") _ => panic!("Negotiated: Invalid state")
@ -142,7 +155,7 @@ impl<TInner> Negotiated<TInner> {
} }
} }
/// Returns a `NegotiatedComplete` future that waits for protocol /// Returns a [`NegotiatedComplete`] future that waits for protocol
/// negotiation to complete. /// negotiation to complete.
pub fn complete(self) -> NegotiatedComplete<TInner> { pub fn complete(self) -> NegotiatedComplete<TInner> {
NegotiatedComplete { inner: Some(self) } NegotiatedComplete { inner: Some(self) }
@ -150,12 +163,14 @@ impl<TInner> Negotiated<TInner> {
} }
/// The states of a `Negotiated` I/O stream. /// The states of a `Negotiated` I/O stream.
#[pin_project]
#[derive(Debug)] #[derive(Debug)]
enum State<R> { enum State<R> {
/// In this state, a `Negotiated` is still expecting to /// In this state, a `Negotiated` is still expecting to
/// receive confirmation of the protocol it as settled on. /// receive confirmation of the protocol it as settled on.
Expecting { Expecting {
/// The underlying I/O stream. /// The underlying I/O stream.
#[pin]
io: MessageReader<R>, io: MessageReader<R>,
/// The expected protocol (i.e. name and version). /// The expected protocol (i.e. name and version).
protocol: Protocol, protocol: Protocol,
@ -167,113 +182,157 @@ enum State<R> {
/// only be pending the sending of the final acknowledgement, /// only be pending the sending of the final acknowledgement,
/// which is prepended to / combined with the next write for /// which is prepended to / combined with the next write for
/// efficiency. /// efficiency.
Completed { io: R, remaining: BytesMut }, Completed { #[pin] io: R, remaining: BytesMut },
/// Temporary state while moving the `io` resource from /// Temporary state while moving the `io` resource from
/// `Expecting` to `Completed`. /// `Expecting` to `Completed`.
Invalid, Invalid,
} }
impl<R> io::Read for Negotiated<R> impl<TInner> AsyncRead for Negotiated<TInner>
where where
R: AsyncRead + AsyncWrite TInner: AsyncRead + AsyncWrite + Unpin
{ {
fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> { #[project]
fn poll_read(mut self: Pin<&mut Self>, cx: &mut Context, buf: &mut [u8])
-> Poll<Result<usize, io::Error>>
{
loop { loop {
if let State::Completed { io, remaining } = &mut self.state { #[project]
// If protocol negotiation is complete and there is no match self.as_mut().project().state.project() {
// remaining data to be flushed, commence with reading. State::Completed { io, remaining } => {
if remaining.is_empty() { // If protocol negotiation is complete and there is no
return io.read(buf) // remaining data to be flushed, commence with reading.
} if remaining.is_empty() {
return io.poll_read(cx, buf)
}
},
_ => {}
} }
// Poll the `Negotiated`, driving protocol negotiation to completion, // Poll the `Negotiated`, driving protocol negotiation to completion,
// including flushing of any remaining data. // including flushing of any remaining data.
let result = self.poll(); match self.as_mut().poll(cx) {
Poll::Ready(Ok(())) => {},
// There is still remaining data to be sent before data relating Poll::Pending => return Poll::Pending,
// to the negotiated protocol can be read. Poll::Ready(Err(err)) => return Poll::Ready(Err(From::from(err))),
if let Ok(Async::NotReady) = result {
return Err(io::ErrorKind::WouldBlock.into())
}
if let Err(err) = result {
return Err(err.into())
} }
} }
} }
}
impl<TInner> AsyncRead for Negotiated<TInner> // TODO: implement once method is stabilized in the futures crate
where /*unsafe fn initializer(&self) -> Initializer {
TInner: AsyncRead + AsyncWrite
{
unsafe fn prepare_uninitialized_buffer(&self, buf: &mut [u8]) -> bool {
match &self.state { match &self.state {
State::Completed { io, .. } => State::Completed { io, .. } => io.initializer(),
io.prepare_uninitialized_buffer(buf), State::Expecting { io, .. } => io.inner_ref().initializer(),
State::Expecting { io, .. } => State::Invalid => panic!("Negotiated: Invalid state"),
io.inner_ref().prepare_uninitialized_buffer(buf),
State::Invalid => panic!("Negotiated: Invalid state")
} }
} }*/
}
impl<TInner> io::Write for Negotiated<TInner> #[project]
where fn poll_read_vectored(mut self: Pin<&mut Self>, cx: &mut Context, bufs: &mut [IoSliceMut])
TInner: AsyncWrite -> Poll<Result<usize, io::Error>>
{ {
fn write(&mut self, buf: &[u8]) -> io::Result<usize> { loop {
match &mut self.state { #[project]
State::Completed { io, ref mut remaining } => { match self.as_mut().project().state.project() {
while !remaining.is_empty() { State::Completed { io, remaining } => {
let n = io.write(&remaining)?; // If protocol negotiation is complete and there is no
if n == 0 { // remaining data to be flushed, commence with reading.
return Err(io::ErrorKind::WriteZero.into()) if remaining.is_empty() {
return io.poll_read_vectored(cx, bufs)
} }
remaining.advance(n); },
} _ => {}
io.write(buf) }
},
State::Expecting { io, .. } => io.write(buf),
State::Invalid => panic!("Negotiated: Invalid state")
}
}
fn flush(&mut self) -> io::Result<()> { // Poll the `Negotiated`, driving protocol negotiation to completion,
match &mut self.state { // including flushing of any remaining data.
State::Completed { io, ref mut remaining } => { match self.as_mut().poll(cx) {
while !remaining.is_empty() { Poll::Ready(Ok(())) => {},
let n = io.write(remaining)?; Poll::Pending => return Poll::Pending,
if n == 0 { Poll::Ready(Err(err)) => return Poll::Ready(Err(From::from(err))),
return Err(io::Error::new( }
io::ErrorKind::WriteZero,
"Failed to write remaining buffer."))
}
remaining.advance(n);
}
io.flush()
},
State::Expecting { io, .. } => io.flush(),
State::Invalid => panic!("Negotiated: Invalid state")
} }
} }
} }
impl<TInner> AsyncWrite for Negotiated<TInner> impl<TInner> AsyncWrite for Negotiated<TInner>
where where
TInner: AsyncWrite + AsyncRead TInner: AsyncWrite + AsyncRead + Unpin
{ {
fn shutdown(&mut self) -> Poll<(), io::Error> { #[project]
fn poll_write(self: Pin<&mut Self>, cx: &mut Context, buf: &[u8]) -> Poll<Result<usize, io::Error>> {
#[project]
match self.project().state.project() {
State::Completed { mut io, remaining } => {
while !remaining.is_empty() {
let n = ready!(io.as_mut().poll_write(cx, &remaining)?);
if n == 0 {
return Poll::Ready(Err(io::ErrorKind::WriteZero.into()))
}
remaining.advance(n);
}
io.poll_write(cx, buf)
},
State::Expecting { io, .. } => io.poll_write(cx, buf),
State::Invalid => panic!("Negotiated: Invalid state"),
}
}
#[project]
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), io::Error>> {
#[project]
match self.project().state.project() {
State::Completed { mut io, remaining } => {
while !remaining.is_empty() {
let n = ready!(io.as_mut().poll_write(cx, &remaining)?);
if n == 0 {
return Poll::Ready(Err(io::ErrorKind::WriteZero.into()))
}
remaining.advance(n);
}
io.poll_flush(cx)
},
State::Expecting { io, .. } => io.poll_flush(cx),
State::Invalid => panic!("Negotiated: Invalid state"),
}
}
#[project]
fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), io::Error>> {
// Ensure all data has been flushed and expected negotiation messages // Ensure all data has been flushed and expected negotiation messages
// have been received. // have been received.
try_ready!(self.poll().map_err(Into::<io::Error>::into)); ready!(self.as_mut().poll(cx).map_err(Into::<io::Error>::into)?);
ready!(self.as_mut().poll_flush(cx).map_err(Into::<io::Error>::into)?);
// Continue with the shutdown of the underlying I/O stream. // Continue with the shutdown of the underlying I/O stream.
match &mut self.state { #[project]
State::Completed { io, .. } => io.shutdown(), match self.project().state.project() {
State::Expecting { io, .. } => io.shutdown(), State::Completed { io, .. } => io.poll_close(cx),
State::Invalid => panic!("Negotiated: Invalid state") State::Expecting { io, .. } => io.poll_close(cx),
State::Invalid => panic!("Negotiated: Invalid state"),
}
}
#[project]
fn poll_write_vectored(self: Pin<&mut Self>, cx: &mut Context, bufs: &[IoSlice])
-> Poll<Result<usize, io::Error>>
{
#[project]
match self.project().state.project() {
State::Completed { mut io, remaining } => {
while !remaining.is_empty() {
let n = ready!(io.as_mut().poll_write(cx, &remaining)?);
if n == 0 {
return Poll::Ready(Err(io::ErrorKind::WriteZero.into()))
}
remaining.advance(n);
}
io.poll_write_vectored(cx, bufs)
},
State::Expecting { io, .. } => io.poll_write_vectored(cx, bufs),
State::Invalid => panic!("Negotiated: Invalid state"),
} }
} }
} }
@ -300,12 +359,12 @@ impl From<io::Error> for NegotiationError {
} }
} }
impl Into<io::Error> for NegotiationError { impl From<NegotiationError> for io::Error {
fn into(self) -> io::Error { fn from(err: NegotiationError) -> io::Error {
if let NegotiationError::ProtocolError(e) = self { if let NegotiationError::ProtocolError(e) = err {
return e.into() return e.into()
} }
io::Error::new(io::ErrorKind::Other, self) io::Error::new(io::ErrorKind::Other, err)
} }
} }
@ -333,27 +392,33 @@ impl fmt::Display for NegotiationError {
mod tests { mod tests {
use super::*; use super::*;
use quickcheck::*; use quickcheck::*;
use std::io::Write; use std::{io::Write, task::Poll};
/// An I/O resource with a fixed write capacity (total and per write op). /// An I/O resource with a fixed write capacity (total and per write op).
struct Capped { buf: Vec<u8>, step: usize } struct Capped { buf: Vec<u8>, step: usize }
impl io::Write for Capped { impl AsyncRead for Capped {
fn write(&mut self, buf: &[u8]) -> io::Result<usize> { fn poll_read(self: Pin<&mut Self>, _: &mut Context, _: &mut [u8]) -> Poll<Result<usize, io::Error>> {
if self.buf.len() + buf.len() > self.buf.capacity() { unreachable!()
return Err(io::ErrorKind::WriteZero.into())
}
self.buf.write(&buf[.. usize::min(self.step, buf.len())])
}
fn flush(&mut self) -> io::Result<()> {
Ok(())
} }
} }
impl AsyncWrite for Capped { impl AsyncWrite for Capped {
fn shutdown(&mut self) -> Poll<(), io::Error> { fn poll_write(mut self: Pin<&mut Self>, _: &mut Context, buf: &[u8]) -> Poll<Result<usize, io::Error>> {
Ok(().into()) if self.buf.len() + buf.len() > self.buf.capacity() {
return Poll::Ready(Err(io::ErrorKind::WriteZero.into()))
}
let len = usize::min(self.step, buf.len());
let n = Write::write(&mut self.buf, &buf[.. len]).unwrap();
Poll::Ready(Ok(n))
}
fn poll_flush(self: Pin<&mut Self>, _: &mut Context) -> Poll<Result<(), io::Error>> {
Poll::Ready(Ok(()))
}
fn poll_close(self: Pin<&mut Self>, _: &mut Context) -> Poll<Result<(), io::Error>> {
Poll::Ready(Ok(()))
} }
} }
@ -369,7 +434,7 @@ mod tests {
loop { loop {
// Write until `new` has been fully written or the capped buffer runs // Write until `new` has been fully written or the capped buffer runs
// over capacity and yields WriteZero. // over capacity and yields WriteZero.
match io.write(&new[written..]) { match future::poll_fn(|cx| Pin::new(&mut io).poll_write(cx, &new[written..])).now_or_never().unwrap() {
Ok(n) => Ok(n) =>
if let State::Completed { remaining, .. } = &io.state { if let State::Completed { remaining, .. } = &io.state {
assert!(remaining.is_empty()); assert!(remaining.is_empty());
@ -388,7 +453,7 @@ mod tests {
return TestResult::failed() return TestResult::failed()
} }
} }
Err(e) => panic!("Unexpected error: {:?}", e) Err(e) => panic!("Unexpected error: {:?}", e),
} }
} }
} }

View File

@ -25,12 +25,11 @@
//! `Stream` and `Sink` implementations of `MessageIO` and //! `Stream` and `Sink` implementations of `MessageIO` and
//! `MessageReader`. //! `MessageReader`.
use bytes::{Bytes, BytesMut, BufMut};
use crate::length_delimited::{LengthDelimited, LengthDelimitedReader}; use crate::length_delimited::{LengthDelimited, LengthDelimitedReader};
use futures::{prelude::*, try_ready};
use log::trace; use bytes::{Bytes, BytesMut, BufMut};
use std::{io, fmt, error::Error, convert::TryFrom}; use futures::{prelude::*, io::IoSlice, ready};
use tokio_io::{AsyncRead, AsyncWrite}; use std::{convert::TryFrom, io, fmt, error::Error, pin::Pin, task::{Context, Poll}};
use unsigned_varint as uvi; use unsigned_varint as uvi;
/// The maximum number of supported protocols that can be processed. /// The maximum number of supported protocols that can be processed.
@ -264,7 +263,9 @@ impl Message {
} }
/// A `MessageIO` implements a [`Stream`] and [`Sink`] of [`Message`]s. /// A `MessageIO` implements a [`Stream`] and [`Sink`] of [`Message`]s.
#[pin_project::pin_project]
pub struct MessageIO<R> { pub struct MessageIO<R> {
#[pin]
inner: LengthDelimited<R>, inner: LengthDelimited<R>,
} }
@ -277,8 +278,8 @@ impl<R> MessageIO<R> {
Self { inner: LengthDelimited::new(inner) } Self { inner: LengthDelimited::new(inner) }
} }
/// Converts the `MessageIO` into a `MessageReader`, dropping the /// Converts the [`MessageIO`] into a [`MessageReader`], dropping the
/// `Message`-oriented `Sink` in favour of direct `AsyncWrite` access /// [`Message`]-oriented `Sink` in favour of direct `AsyncWrite` access
/// to the underlying I/O stream. /// to the underlying I/O stream.
/// ///
/// This is typically done if further negotiation messages are expected to be /// This is typically done if further negotiation messages are expected to be
@ -288,7 +289,7 @@ impl<R> MessageIO<R> {
MessageReader { inner: self.inner.into_reader() } MessageReader { inner: self.inner.into_reader() }
} }
/// Drops the `MessageIO` resource, yielding the underlying I/O stream /// Drops the [`MessageIO`] resource, yielding the underlying I/O stream
/// together with the remaining write buffer containing the protocol /// together with the remaining write buffer containing the protocol
/// negotiation frame data that has not yet been written to the I/O stream. /// negotiation frame data that has not yet been written to the I/O stream.
/// ///
@ -309,28 +310,28 @@ impl<R> MessageIO<R> {
} }
} }
impl<R> Sink for MessageIO<R> impl<R> Sink<Message> for MessageIO<R>
where where
R: AsyncWrite, R: AsyncWrite,
{ {
type SinkItem = Message; type Error = ProtocolError;
type SinkError = ProtocolError;
fn start_send(&mut self, msg: Self::SinkItem) -> StartSend<Self::SinkItem, Self::SinkError> { fn poll_ready(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
self.project().inner.poll_ready(cx).map_err(From::from)
}
fn start_send(self: Pin<&mut Self>, item: Message) -> Result<(), Self::Error> {
let mut buf = BytesMut::new(); let mut buf = BytesMut::new();
msg.encode(&mut buf)?; item.encode(&mut buf)?;
match self.inner.start_send(buf.freeze())? { self.project().inner.start_send(buf.freeze()).map_err(From::from)
AsyncSink::NotReady(_) => Ok(AsyncSink::NotReady(msg)),
AsyncSink::Ready => Ok(AsyncSink::Ready),
}
} }
fn poll_complete(&mut self) -> Poll<(), Self::SinkError> { fn poll_flush(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
Ok(self.inner.poll_complete()?) self.project().inner.poll_flush(cx).map_err(From::from)
} }
fn close(&mut self) -> Poll<(), Self::SinkError> { fn poll_close(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
Ok(self.inner.close()?) self.project().inner.poll_close(cx).map_err(From::from)
} }
} }
@ -338,18 +339,24 @@ impl<R> Stream for MessageIO<R>
where where
R: AsyncRead R: AsyncRead
{ {
type Item = Message; type Item = Result<Message, ProtocolError>;
type Error = ProtocolError;
fn poll(&mut self) -> Poll<Option<Self::Item>, Self::Error> { fn poll_next(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Option<Self::Item>> {
poll_stream(&mut self.inner) match poll_stream(self.project().inner, cx) {
Poll::Pending => Poll::Pending,
Poll::Ready(None) => Poll::Ready(None),
Poll::Ready(Some(Ok(m))) => Poll::Ready(Some(Ok(m))),
Poll::Ready(Some(Err(err))) => Poll::Ready(Some(Err(From::from(err)))),
}
} }
} }
/// A `MessageReader` implements a `Stream` of `Message`s on an underlying /// A `MessageReader` implements a `Stream` of `Message`s on an underlying
/// I/O resource combined with direct `AsyncWrite` access. /// I/O resource combined with direct `AsyncWrite` access.
#[pin_project::pin_project]
#[derive(Debug)] #[derive(Debug)]
pub struct MessageReader<R> { pub struct MessageReader<R> {
#[pin]
inner: LengthDelimitedReader<R> inner: LengthDelimitedReader<R>
} }
@ -373,35 +380,16 @@ impl<R> MessageReader<R> {
pub fn into_inner(self) -> (R, BytesMut) { pub fn into_inner(self) -> (R, BytesMut) {
self.inner.into_inner() self.inner.into_inner()
} }
/// Returns a reference to the underlying I/O stream.
pub fn inner_ref(&self) -> &R {
self.inner.inner_ref()
}
} }
impl<R> Stream for MessageReader<R> impl<R> Stream for MessageReader<R>
where where
R: AsyncRead R: AsyncRead
{ {
type Item = Message; type Item = Result<Message, ProtocolError>;
type Error = ProtocolError;
fn poll(&mut self) -> Poll<Option<Self::Item>, Self::Error> { fn poll_next(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Option<Self::Item>> {
poll_stream(&mut self.inner) poll_stream(self.project().inner, cx)
}
}
impl<R> io::Write for MessageReader<R>
where
R: AsyncWrite
{
fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
self.inner.write(buf)
}
fn flush(&mut self) -> io::Result<()> {
self.inner.flush()
} }
} }
@ -409,24 +397,39 @@ impl<TInner> AsyncWrite for MessageReader<TInner>
where where
TInner: AsyncWrite TInner: AsyncWrite
{ {
fn shutdown(&mut self) -> Poll<(), io::Error> { fn poll_write(self: Pin<&mut Self>, cx: &mut Context, buf: &[u8]) -> Poll<Result<usize, io::Error>> {
self.inner.shutdown() self.project().inner.poll_write(cx, buf)
}
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), io::Error>> {
self.project().inner.poll_flush(cx)
}
fn poll_close(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), io::Error>> {
self.project().inner.poll_close(cx)
}
fn poll_write_vectored(self: Pin<&mut Self>, cx: &mut Context, bufs: &[IoSlice]) -> Poll<Result<usize, io::Error>> {
self.project().inner.poll_write_vectored(cx, bufs)
} }
} }
fn poll_stream<S>(stream: &mut S) -> Poll<Option<Message>, ProtocolError> fn poll_stream<S>(stream: Pin<&mut S>, cx: &mut Context) -> Poll<Option<Result<Message, ProtocolError>>>
where where
S: Stream<Item = Bytes, Error = io::Error>, S: Stream<Item = Result<Bytes, io::Error>>,
{ {
let msg = if let Some(msg) = try_ready!(stream.poll()) { let msg = if let Some(msg) = ready!(stream.poll_next(cx)?) {
Message::decode(msg)? match Message::decode(msg) {
Ok(m) => m,
Err(err) => return Poll::Ready(Some(Err(err))),
}
} else { } else {
return Ok(Async::Ready(None)) return Poll::Ready(None)
}; };
trace!("Received message: {:?}", msg); log::trace!("Received message: {:?}", msg);
Ok(Async::Ready(Some(msg))) Poll::Ready(Some(Ok(msg)))
} }
/// A protocol error. /// A protocol error.

View File

@ -25,164 +25,156 @@
use crate::{Version, NegotiationError}; use crate::{Version, NegotiationError};
use crate::dialer_select::{dialer_select_proto_parallel, dialer_select_proto_serial}; use crate::dialer_select::{dialer_select_proto_parallel, dialer_select_proto_serial};
use crate::{dialer_select_proto, listener_select_proto}; use crate::{dialer_select_proto, listener_select_proto};
use async_std::net::{TcpListener, TcpStream};
use futures::prelude::*; use futures::prelude::*;
use tokio::runtime::current_thread::Runtime;
use tokio_tcp::{TcpListener, TcpStream};
use tokio_io::io as nio;
#[test] #[test]
fn select_proto_basic() { fn select_proto_basic() {
fn run(version: Version) { async fn run(version: Version) {
let listener = TcpListener::bind(&"127.0.0.1:0".parse().unwrap()).unwrap(); let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
let listener_addr = listener.local_addr().unwrap(); let listener_addr = listener.local_addr().unwrap();
let server = listener let server = async_std::task::spawn(async move {
.incoming() let connec = listener.accept().await.unwrap().0;
.into_future() let protos = vec![b"/proto1", b"/proto2"];
.map(|s| s.0.unwrap()) let (proto, mut io) = listener_select_proto(connec, protos).await.unwrap();
.map_err(|(e, _)| e.into()) assert_eq!(proto, b"/proto2");
.and_then(move |connec| {
let protos = vec![b"/proto1", b"/proto2"];
listener_select_proto(connec, protos)
})
.and_then(|(proto, io)| {
nio::write_all(io, b"pong").from_err().map(move |_| proto)
});
let client = TcpStream::connect(&listener_addr) let mut out = vec![0; 32];
.from_err() let n = io.read(&mut out).await.unwrap();
.and_then(move |connec| { out.truncate(n);
let protos = vec![b"/proto3", b"/proto2"]; assert_eq!(out, b"ping");
dialer_select_proto(connec, protos, version)
})
.and_then(|(proto, io)| {
nio::write_all(io, b"ping").from_err().map(move |(io, _)| (proto, io))
})
.and_then(|(proto, io)| {
nio::read_exact(io, [0; 4]).from_err().map(move |(_, msg)| {
assert_eq!(&msg, b"pong");
proto
})
});
let mut rt = Runtime::new().unwrap(); io.write_all(b"pong").await.unwrap();
let (dialer_chosen, listener_chosen) = io.flush().await.unwrap();
rt.block_on(client.join(server)).unwrap(); });
assert_eq!(dialer_chosen, b"/proto2"); let client = async_std::task::spawn(async move {
assert_eq!(listener_chosen, b"/proto2"); let connec = TcpStream::connect(&listener_addr).await.unwrap();
let protos = vec![b"/proto3", b"/proto2"];
let (proto, mut io) = dialer_select_proto(connec, protos.into_iter(), version)
.await.unwrap();
assert_eq!(proto, b"/proto2");
io.write_all(b"ping").await.unwrap();
io.flush().await.unwrap();
let mut out = vec![0; 32];
let n = io.read(&mut out).await.unwrap();
out.truncate(n);
assert_eq!(out, b"pong");
});
server.await;
client.await;
} }
run(Version::V1); async_std::task::block_on(run(Version::V1));
run(Version::V1Lazy); async_std::task::block_on(run(Version::V1Lazy));
} }
#[test] #[test]
fn no_protocol_found() { fn no_protocol_found() {
fn run(version: Version) { async fn run(version: Version) {
let listener = TcpListener::bind(&"127.0.0.1:0".parse().unwrap()).unwrap(); let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
let listener_addr = listener.local_addr().unwrap(); let listener_addr = listener.local_addr().unwrap();
let server = listener let server = async_std::task::spawn(async move {
.incoming() let connec = listener.accept().await.unwrap().0;
.into_future() let protos = vec![b"/proto1", b"/proto2"];
.map(|s| s.0.unwrap()) let io = match listener_select_proto(connec, protos).await {
.map_err(|(e, _)| e.into()) Ok((_, io)) => io,
.and_then(move |connec| { // We don't explicitly check for `Failed` because the client might close the connection when it
let protos = vec![b"/proto1", b"/proto2"]; // realizes that we have no protocol in common.
listener_select_proto(connec, protos) Err(_) => return,
}) };
.and_then(|(proto, io)| io.complete().map(move |_| proto)); match io.complete().await {
Err(NegotiationError::Failed) => {},
_ => panic!(),
}
});
let client = TcpStream::connect(&listener_addr) let client = async_std::task::spawn(async move {
.from_err() let connec = TcpStream::connect(&listener_addr).await.unwrap();
.and_then(move |connec| { let protos = vec![b"/proto3", b"/proto4"];
let protos = vec![b"/proto3", b"/proto4"]; let io = match dialer_select_proto(connec, protos.into_iter(), version).await {
dialer_select_proto(connec, protos, version) Err(NegotiationError::Failed) => return,
}) Ok((_, io)) => io,
.and_then(|(proto, io)| io.complete().map(move |_| proto)); Err(_) => panic!()
};
match io.complete().await {
Err(NegotiationError::Failed) => {},
_ => panic!(),
}
});
let mut rt = Runtime::new().unwrap(); server.await;
match rt.block_on(client.join(server)) { client.await;
Err(NegotiationError::Failed) => (),
e => panic!("{:?}", e),
}
} }
run(Version::V1); async_std::task::block_on(run(Version::V1));
run(Version::V1Lazy); async_std::task::block_on(run(Version::V1Lazy));
} }
#[test] #[test]
fn select_proto_parallel() { fn select_proto_parallel() {
fn run(version: Version) { async fn run(version: Version) {
let listener = TcpListener::bind(&"127.0.0.1:0".parse().unwrap()).unwrap(); let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
let listener_addr = listener.local_addr().unwrap(); let listener_addr = listener.local_addr().unwrap();
let server = listener let server = async_std::task::spawn(async move {
.incoming() let connec = listener.accept().await.unwrap().0;
.into_future() let protos = vec![b"/proto1", b"/proto2"];
.map(|s| s.0.unwrap()) let (proto, io) = listener_select_proto(connec, protos).await.unwrap();
.map_err(|(e, _)| e.into()) assert_eq!(proto, b"/proto2");
.and_then(move |connec| { io.complete().await.unwrap();
let protos = vec![b"/proto1", b"/proto2"]; });
listener_select_proto(connec, protos)
})
.and_then(|(proto, io)| io.complete().map(move |_| proto));
let client = TcpStream::connect(&listener_addr) let client = async_std::task::spawn(async move {
.from_err() let connec = TcpStream::connect(&listener_addr).await.unwrap();
.and_then(move |connec| { let protos = vec![b"/proto3", b"/proto2"];
let protos = vec![b"/proto3", b"/proto2"]; let (proto, io) = dialer_select_proto_parallel(connec, protos.into_iter(), version)
dialer_select_proto_parallel(connec, protos.into_iter(), version) .await.unwrap();
}) assert_eq!(proto, b"/proto2");
.and_then(|(proto, io)| io.complete().map(move |_| proto)); io.complete().await.unwrap();
});
let mut rt = Runtime::new().unwrap(); server.await;
let (dialer_chosen, listener_chosen) = client.await;
rt.block_on(client.join(server)).unwrap();
assert_eq!(dialer_chosen, b"/proto2");
assert_eq!(listener_chosen, b"/proto2");
} }
run(Version::V1); async_std::task::block_on(run(Version::V1));
run(Version::V1Lazy); async_std::task::block_on(run(Version::V1Lazy));
} }
#[test] #[test]
fn select_proto_serial() { fn select_proto_serial() {
fn run(version: Version) { async fn run(version: Version) {
let listener = TcpListener::bind(&"127.0.0.1:0".parse().unwrap()).unwrap(); let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
let listener_addr = listener.local_addr().unwrap(); let listener_addr = listener.local_addr().unwrap();
let server = listener let server = async_std::task::spawn(async move {
.incoming() let connec = listener.accept().await.unwrap().0;
.into_future() let protos = vec![b"/proto1", b"/proto2"];
.map(|s| s.0.unwrap()) let (proto, io) = listener_select_proto(connec, protos).await.unwrap();
.map_err(|(e, _)| e.into()) assert_eq!(proto, b"/proto2");
.and_then(move |connec| { io.complete().await.unwrap();
let protos = vec![b"/proto1", b"/proto2"]; });
listener_select_proto(connec, protos)
})
.and_then(|(proto, io)| io.complete().map(move |_| proto));
let client = TcpStream::connect(&listener_addr) let client = async_std::task::spawn(async move {
.from_err() let connec = TcpStream::connect(&listener_addr).await.unwrap();
.and_then(move |connec| { let protos = vec![b"/proto3", b"/proto2"];
let protos = vec![b"/proto3", b"/proto2"]; let (proto, io) = dialer_select_proto_serial(connec, protos.into_iter(), version)
dialer_select_proto_serial(connec, protos.into_iter(), version) .await.unwrap();
}) assert_eq!(proto, b"/proto2");
.and_then(|(proto, io)| io.complete().map(move |_| proto)); io.complete().await.unwrap();
});
let mut rt = Runtime::new().unwrap(); server.await;
let (dialer_chosen, listener_chosen) = client.await;
rt.block_on(client.join(server)).unwrap();
assert_eq!(dialer_chosen, b"/proto2");
assert_eq!(listener_chosen, b"/proto2");
} }
run(Version::V1); async_std::task::block_on(run(Version::V1));
run(Version::V1Lazy); async_std::task::block_on(run(Version::V1Lazy));
} }