Update multistream-select to stable futures (#1484)

* Update multistream-select to stable futures

* Fix intradoc links
This commit is contained in:
Pierre Krieger
2020-03-11 14:49:41 +01:00
committed by GitHub
parent 2084fadd86
commit 31271fc824
11 changed files with 851 additions and 652 deletions

View File

@ -15,7 +15,7 @@ bs58 = "0.3.0"
ed25519-dalek = "1.0.0-pre.3"
either = "1.5"
fnv = "1.0"
futures = { version = "0.3.1", features = ["compat", "io-compat", "executor", "thread-pool"] }
futures = { version = "0.3.1", features = ["executor", "thread-pool"] }
futures-timer = "3"
lazy_static = "1.2"
libsecp256k1 = { version = "0.3.1", optional = true }

View File

@ -41,7 +41,7 @@ mod keys_proto {
/// Multi-address re-export.
pub use multiaddr;
pub type Negotiated<T> = futures::compat::Compat01As03<multistream_select::Negotiated<futures::compat::Compat<T>>>;
pub type Negotiated<T> = multistream_select::Negotiated<T>;
mod peer_id;
mod translation;

View File

@ -20,7 +20,7 @@
use crate::{ConnectedPoint, Negotiated};
use crate::upgrade::{InboundUpgrade, OutboundUpgrade, UpgradeError, ProtocolName};
use futures::{future::Either, prelude::*, compat::Compat, compat::Compat01As03, compat::Future01CompatExt};
use futures::{future::Either, prelude::*};
use log::debug;
use multistream_select::{self, DialerSelectFuture, ListenerSelectFuture};
use std::{iter, mem, pin::Pin, task::Context, task::Poll};
@ -48,7 +48,7 @@ where
U: InboundUpgrade<Negotiated<C>>,
{
let iter = up.protocol_info().into_iter().map(NameWrap as fn(_) -> NameWrap<_>);
let future = multistream_select::listener_select_proto(Compat::new(conn), iter).compat();
let future = multistream_select::listener_select_proto(conn, iter);
InboundUpgradeApply {
inner: InboundUpgradeApplyState::Init { future, upgrade: up }
}
@ -61,7 +61,7 @@ where
U: OutboundUpgrade<Negotiated<C>>
{
let iter = up.protocol_info().into_iter().map(NameWrap as fn(_) -> NameWrap<_>);
let future = multistream_select::dialer_select_proto(Compat::new(conn), iter, v).compat();
let future = multistream_select::dialer_select_proto(conn, iter, v);
OutboundUpgradeApply {
inner: OutboundUpgradeApplyState::Init { future, upgrade: up }
}
@ -82,7 +82,7 @@ where
U: InboundUpgrade<Negotiated<C>>,
{
Init {
future: Compat01As03<ListenerSelectFuture<Compat<C>, NameWrap<U::Info>>>,
future: ListenerSelectFuture<C, NameWrap<U::Info>>,
upgrade: U,
},
Upgrade {
@ -117,7 +117,7 @@ where
}
};
self.inner = InboundUpgradeApplyState::Upgrade {
future: Box::pin(upgrade.upgrade_inbound(Compat01As03::new(io), info.0))
future: Box::pin(upgrade.upgrade_inbound(io, info.0))
};
}
InboundUpgradeApplyState::Upgrade { mut future } => {
@ -158,7 +158,7 @@ where
U: OutboundUpgrade<Negotiated<C>>
{
Init {
future: Compat01As03<DialerSelectFuture<Compat<C>, NameWrapIter<<U::InfoIter as IntoIterator>::IntoIter>>>,
future: DialerSelectFuture<C, NameWrapIter<<U::InfoIter as IntoIterator>::IntoIter>>,
upgrade: U
},
Upgrade {
@ -193,7 +193,7 @@ where
}
};
self.inner = OutboundUpgradeApplyState::Upgrade {
future: Box::pin(upgrade.upgrade_outbound(Compat01As03::new(connection), info.0))
future: Box::pin(upgrade.upgrade_outbound(connection, info.0))
};
}
OutboundUpgradeApplyState::Upgrade { mut future } => {

View File

@ -11,14 +11,14 @@ edition = "2018"
[dependencies]
bytes = "0.5"
futures = "0.1"
futures = "0.3"
log = "0.4"
pin-project = "0.4.8"
smallvec = "1.0"
tokio-io = "0.1"
unsigned-varint = "0.3"
unsigned-varint = "0.3.2"
[dev-dependencies]
tokio = "0.1"
tokio-tcp = "0.1"
async-std = "1.2"
quickcheck = "0.9.0"
rand = "0.7.2"
rw-stream-sink = "0.2.1"

View File

@ -20,12 +20,11 @@
//! Protocol negotiation strategies for the peer acting as the dialer.
use crate::protocol::{Protocol, ProtocolError, MessageIO, Message, Version};
use futures::{future::Either, prelude::*};
use log::debug;
use std::{io, iter, mem, convert::TryFrom};
use tokio_io::{AsyncRead, AsyncWrite};
use crate::{Negotiated, NegotiationError};
use crate::protocol::{Protocol, ProtocolError, MessageIO, Message, Version};
use futures::{future::Either, prelude::*};
use std::{convert::TryFrom as _, io, iter, mem, pin::Pin, task::{Context, Poll}};
/// Returns a `Future` that negotiates a protocol on the given I/O stream
/// for a peer acting as the _dialer_ (or _initiator_).
@ -60,9 +59,9 @@ where
let iter = protocols.into_iter();
// We choose between the "serial" and "parallel" strategies based on the number of protocols.
if iter.size_hint().1.map(|n| n <= 3).unwrap_or(false) {
Either::A(dialer_select_proto_serial(inner, iter, version))
Either::Left(dialer_select_proto_serial(inner, iter, version))
} else {
Either::B(dialer_select_proto_parallel(inner, iter, version))
Either::Right(dialer_select_proto_parallel(inner, iter, version))
}
}
@ -129,6 +128,7 @@ where
/// A `Future` returned by [`dialer_select_proto_serial`] which negotiates
/// a protocol iteratively by considering one protocol after the other.
#[pin_project::pin_project]
pub struct DialerSelectSeq<R, I>
where
R: AsyncRead + AsyncWrite,
@ -155,83 +155,107 @@ where
impl<R, I> Future for DialerSelectSeq<R, I>
where
R: AsyncRead + AsyncWrite,
// The Unpin bound here is required because we produce a `Negotiated<R>` as the output.
// It also makes the implementation considerably easier to write.
R: AsyncRead + AsyncWrite + Unpin,
I: Iterator,
I::Item: AsRef<[u8]>
{
type Item = (I::Item, Negotiated<R>);
type Error = NegotiationError;
type Output = Result<(I::Item, Negotiated<R>), NegotiationError>;
fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
let this = self.project();
fn poll(&mut self) -> Poll<Self::Item, Self::Error> {
loop {
match mem::replace(&mut self.state, SeqState::Done) {
match mem::replace(this.state, SeqState::Done) {
SeqState::SendHeader { mut io } => {
if io.start_send(Message::Header(self.version))?.is_not_ready() {
self.state = SeqState::SendHeader { io };
return Ok(Async::NotReady)
match Pin::new(&mut io).poll_ready(cx)? {
Poll::Ready(()) => {},
Poll::Pending => {
*this.state = SeqState::SendHeader { io };
return Poll::Pending
},
}
let protocol = self.protocols.next().ok_or(NegotiationError::Failed)?;
self.state = SeqState::SendProtocol { io, protocol };
if let Err(err) = Pin::new(&mut io).start_send(Message::Header(*this.version)) {
return Poll::Ready(Err(From::from(err)));
}
let protocol = this.protocols.next().ok_or(NegotiationError::Failed)?;
*this.state = SeqState::SendProtocol { io, protocol };
}
SeqState::SendProtocol { mut io, protocol } => {
match Pin::new(&mut io).poll_ready(cx)? {
Poll::Ready(()) => {},
Poll::Pending => {
*this.state = SeqState::SendProtocol { io, protocol };
return Poll::Pending
},
}
let p = Protocol::try_from(protocol.as_ref())?;
if io.start_send(Message::Protocol(p.clone()))?.is_not_ready() {
self.state = SeqState::SendProtocol { io, protocol };
return Ok(Async::NotReady)
if let Err(err) = Pin::new(&mut io).start_send(Message::Protocol(p.clone())) {
return Poll::Ready(Err(From::from(err)));
}
debug!("Dialer: Proposed protocol: {}", p);
if self.protocols.peek().is_some() {
self.state = SeqState::FlushProtocol { io, protocol }
log::debug!("Dialer: Proposed protocol: {}", p);
if this.protocols.peek().is_some() {
*this.state = SeqState::FlushProtocol { io, protocol }
} else {
match self.version {
Version::V1 => self.state = SeqState::FlushProtocol { io, protocol },
match this.version {
Version::V1 => *this.state = SeqState::FlushProtocol { io, protocol },
Version::V1Lazy => {
debug!("Dialer: Expecting proposed protocol: {}", p);
let io = Negotiated::expecting(io.into_reader(), p, self.version);
return Ok(Async::Ready((protocol, io)))
log::debug!("Dialer: Expecting proposed protocol: {}", p);
let io = Negotiated::expecting(io.into_reader(), p, *this.version);
return Poll::Ready(Ok((protocol, io)))
}
}
}
}
SeqState::FlushProtocol { mut io, protocol } => {
if io.poll_complete()?.is_not_ready() {
self.state = SeqState::FlushProtocol { io, protocol };
return Ok(Async::NotReady)
match Pin::new(&mut io).poll_flush(cx)? {
Poll::Ready(()) => *this.state = SeqState::AwaitProtocol { io, protocol },
Poll::Pending => {
*this.state = SeqState::FlushProtocol { io, protocol };
return Poll::Pending
},
}
self.state = SeqState::AwaitProtocol { io, protocol }
}
SeqState::AwaitProtocol { mut io, protocol } => {
let msg = match io.poll()? {
Async::NotReady => {
self.state = SeqState::AwaitProtocol { io, protocol };
return Ok(Async::NotReady)
let msg = match Pin::new(&mut io).poll_next(cx)? {
Poll::Ready(Some(msg)) => msg,
Poll::Pending => {
*this.state = SeqState::AwaitProtocol { io, protocol };
return Poll::Pending
}
Async::Ready(None) =>
return Err(NegotiationError::from(
io::Error::from(io::ErrorKind::UnexpectedEof))),
Async::Ready(Some(msg)) => msg,
Poll::Ready(None) =>
return Poll::Ready(Err(NegotiationError::from(
io::Error::from(io::ErrorKind::UnexpectedEof)))),
};
match msg {
Message::Header(v) if v == self.version => {
self.state = SeqState::AwaitProtocol { io, protocol };
Message::Header(v) if v == *this.version => {
*this.state = SeqState::AwaitProtocol { io, protocol };
}
Message::Protocol(ref p) if p.as_ref() == protocol.as_ref() => {
debug!("Dialer: Received confirmation for protocol: {}", p);
log::debug!("Dialer: Received confirmation for protocol: {}", p);
let (io, remaining) = io.into_inner();
let io = Negotiated::completed(io, remaining);
return Ok(Async::Ready((protocol, io)))
return Poll::Ready(Ok((protocol, io)));
}
Message::NotAvailable => {
debug!("Dialer: Received rejection of protocol: {}",
log::debug!("Dialer: Received rejection of protocol: {}",
String::from_utf8_lossy(protocol.as_ref()));
let protocol = self.protocols.next()
.ok_or(NegotiationError::Failed)?;
self.state = SeqState::SendProtocol { io, protocol }
let protocol = this.protocols.next().ok_or(NegotiationError::Failed)?;
*this.state = SeqState::SendProtocol { io, protocol }
}
_ => return Err(ProtocolError::InvalidMessage.into())
_ => return Poll::Ready(Err(ProtocolError::InvalidMessage.into())),
}
}
SeqState::Done => panic!("SeqState::poll called after completion")
}
}
@ -241,6 +265,7 @@ where
/// A `Future` returned by [`dialer_select_proto_parallel`] which negotiates
/// a protocol selectively by considering all supported protocols of the remote
/// "in parallel".
#[pin_project::pin_project]
pub struct DialerSelectPar<R, I>
where
R: AsyncRead + AsyncWrite,
@ -267,76 +292,110 @@ where
impl<R, I> Future for DialerSelectPar<R, I>
where
R: AsyncRead + AsyncWrite,
// The Unpin bound here is required because we produce a `Negotiated<R>` as the output.
// It also makes the implementation considerably easier to write.
R: AsyncRead + AsyncWrite + Unpin,
I: Iterator,
I::Item: AsRef<[u8]>
{
type Item = (I::Item, Negotiated<R>);
type Error = NegotiationError;
type Output = Result<(I::Item, Negotiated<R>), NegotiationError>;
fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
let this = self.project();
fn poll(&mut self) -> Poll<Self::Item, Self::Error> {
loop {
match mem::replace(&mut self.state, ParState::Done) {
match mem::replace(this.state, ParState::Done) {
ParState::SendHeader { mut io } => {
if io.start_send(Message::Header(self.version))?.is_not_ready() {
self.state = ParState::SendHeader { io };
return Ok(Async::NotReady)
match Pin::new(&mut io).poll_ready(cx)? {
Poll::Ready(()) => {},
Poll::Pending => {
*this.state = ParState::SendHeader { io };
return Poll::Pending
},
}
self.state = ParState::SendProtocolsRequest { io };
if let Err(err) = Pin::new(&mut io).start_send(Message::Header(*this.version)) {
return Poll::Ready(Err(From::from(err)));
}
*this.state = ParState::SendProtocolsRequest { io };
}
ParState::SendProtocolsRequest { mut io } => {
if io.start_send(Message::ListProtocols)?.is_not_ready() {
self.state = ParState::SendProtocolsRequest { io };
return Ok(Async::NotReady)
match Pin::new(&mut io).poll_ready(cx)? {
Poll::Ready(()) => {},
Poll::Pending => {
*this.state = ParState::SendProtocolsRequest { io };
return Poll::Pending
},
}
debug!("Dialer: Requested supported protocols.");
self.state = ParState::Flush { io }
if let Err(err) = Pin::new(&mut io).start_send(Message::ListProtocols) {
return Poll::Ready(Err(From::from(err)));
}
log::debug!("Dialer: Requested supported protocols.");
*this.state = ParState::Flush { io }
}
ParState::Flush { mut io } => {
if io.poll_complete()?.is_not_ready() {
self.state = ParState::Flush { io };
return Ok(Async::NotReady)
match Pin::new(&mut io).poll_flush(cx)? {
Poll::Ready(()) => *this.state = ParState::RecvProtocols { io },
Poll::Pending => {
*this.state = ParState::Flush { io };
return Poll::Pending
},
}
self.state = ParState::RecvProtocols { io }
}
ParState::RecvProtocols { mut io } => {
let msg = match io.poll()? {
Async::NotReady => {
self.state = ParState::RecvProtocols { io };
return Ok(Async::NotReady)
let msg = match Pin::new(&mut io).poll_next(cx)? {
Poll::Ready(Some(msg)) => msg,
Poll::Pending => {
*this.state = ParState::RecvProtocols { io };
return Poll::Pending
}
Async::Ready(None) =>
return Err(NegotiationError::from(
io::Error::from(io::ErrorKind::UnexpectedEof))),
Async::Ready(Some(msg)) => msg,
Poll::Ready(None) =>
return Poll::Ready(Err(NegotiationError::from(
io::Error::from(io::ErrorKind::UnexpectedEof)))),
};
match &msg {
Message::Header(v) if v == &self.version => {
self.state = ParState::RecvProtocols { io }
Message::Header(v) if v == this.version => {
*this.state = ParState::RecvProtocols { io }
}
Message::Protocols(supported) => {
let protocol = self.protocols.by_ref()
let protocol = this.protocols.by_ref()
.find(|p| supported.iter().any(|s|
s.as_ref() == p.as_ref()))
.ok_or(NegotiationError::Failed)?;
debug!("Dialer: Found supported protocol: {}",
log::debug!("Dialer: Found supported protocol: {}",
String::from_utf8_lossy(protocol.as_ref()));
self.state = ParState::SendProtocol { io, protocol };
*this.state = ParState::SendProtocol { io, protocol };
}
_ => return Err(ProtocolError::InvalidMessage.into())
_ => return Poll::Ready(Err(ProtocolError::InvalidMessage.into())),
}
}
ParState::SendProtocol { mut io, protocol } => {
match Pin::new(&mut io).poll_ready(cx)? {
Poll::Ready(()) => {},
Poll::Pending => {
*this.state = ParState::SendProtocol { io, protocol };
return Poll::Pending
},
}
let p = Protocol::try_from(protocol.as_ref())?;
if io.start_send(Message::Protocol(p.clone()))?.is_not_ready() {
self.state = ParState::SendProtocol { io, protocol };
return Ok(Async::NotReady)
if let Err(err) = Pin::new(&mut io).start_send(Message::Protocol(p.clone())) {
return Poll::Ready(Err(From::from(err)));
}
debug!("Dialer: Expecting proposed protocol: {}", p);
let io = Negotiated::expecting(io.into_reader(), p, self.version);
return Ok(Async::Ready((protocol, io)))
log::debug!("Dialer: Expecting proposed protocol: {}", p);
let io = Negotiated::expecting(io.into_reader(), p, *this.version);
return Poll::Ready(Ok((protocol, io)))
}
ParState::Done => panic!("ParState::poll called after completion")
}
}

View File

@ -18,11 +18,9 @@
// FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
// DEALINGS IN THE SOFTWARE.
use bytes::{Bytes, BytesMut, Buf, BufMut};
use futures::{try_ready, Async, Poll, Sink, StartSend, Stream, AsyncSink};
use std::{io, u16};
use tokio_io::{AsyncRead, AsyncWrite};
use unsigned_varint as uvi;
use bytes::{Bytes, BytesMut, Buf as _, BufMut as _};
use futures::{prelude::*, io::IoSlice};
use std::{convert::TryFrom as _, io, pin::Pin, task::{Poll, Context}, u16};
const MAX_LEN_BYTES: u16 = 2;
const MAX_FRAME_SIZE: u16 = (1 << (MAX_LEN_BYTES * 8 - MAX_LEN_BYTES)) - 1;
@ -34,9 +32,11 @@ const DEFAULT_BUFFER_SIZE: usize = 64;
/// We purposely only support a frame sizes up to 16KiB (2 bytes unsigned varint
/// frame length). Frames mostly consist in a short protocol name, which is highly
/// unlikely to be more than 16KiB long.
#[pin_project::pin_project]
#[derive(Debug)]
pub struct LengthDelimited<R> {
/// The inner I/O resource.
#[pin]
inner: R,
/// Read buffer for a single incoming unsigned-varint length-delimited frame.
read_buffer: BytesMut,
@ -76,20 +76,7 @@ impl<R> LengthDelimited<R> {
}
}
/// Returns a reference to the underlying I/O stream.
pub fn inner_ref(&self) -> &R {
&self.inner
}
/// Returns a mutable reference to the underlying I/O stream.
///
/// > **Note**: Care should be taken to not tamper with the underlying stream of data
/// > coming in, as it may corrupt the stream of frames.
pub fn inner_mut(&mut self) -> &mut R {
&mut self.inner
}
/// Drops the `LengthDelimited` resource, yielding the underlying I/O stream
/// Drops the [`LengthDelimited`] resource, yielding the underlying I/O stream
/// together with the remaining write buffer containing the uvi-framed data
/// that has not yet been written to the underlying I/O stream.
///
@ -107,7 +94,7 @@ impl<R> LengthDelimited<R> {
(self.inner, self.write_buffer)
}
/// Converts the `LengthDelimited` into a `LengthDelimitedReader`, dropping the
/// Converts the [`LengthDelimited`] into a [`LengthDelimitedReader`], dropping the
/// uvi-framed `Sink` in favour of direct `AsyncWrite` access to the underlying
/// I/O stream.
///
@ -121,25 +108,29 @@ impl<R> LengthDelimited<R> {
/// Writes all buffered frame data to the underlying I/O stream,
/// _without flushing it_.
///
/// After this method returns `Async::Ready`, the write buffer of frames
/// After this method returns `Poll::Ready`, the write buffer of frames
/// submitted to the `Sink` is guaranteed to be empty.
pub fn poll_write_buffer(&mut self) -> Poll<(), io::Error>
pub fn poll_write_buffer(self: Pin<&mut Self>, cx: &mut Context)
-> Poll<Result<(), io::Error>>
where
R: AsyncWrite
{
while !self.write_buffer.is_empty() {
let n = try_ready!(self.inner.poll_write(&self.write_buffer));
let mut this = self.project();
if n == 0 {
return Err(io::Error::new(
while !this.write_buffer.is_empty() {
match this.inner.as_mut().poll_write(cx, &this.write_buffer) {
Poll::Pending => return Poll::Pending,
Poll::Ready(Ok(0)) => {
return Poll::Ready(Err(io::Error::new(
io::ErrorKind::WriteZero,
"Failed to write buffered frame."))
"Failed to write buffered frame.")))
}
Poll::Ready(Ok(n)) => this.write_buffer.advance(n),
Poll::Ready(Err(err)) => return Poll::Ready(Err(err)),
}
}
self.write_buffer.advance(n);
}
Ok(Async::Ready(()))
Poll::Ready(Ok(()))
}
}
@ -147,72 +138,67 @@ impl<R> Stream for LengthDelimited<R>
where
R: AsyncRead
{
type Item = Bytes;
type Error = io::Error;
type Item = Result<Bytes, io::Error>;
fn poll_next(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Option<Self::Item>> {
let mut this = self.project();
fn poll(&mut self) -> Poll<Option<Self::Item>, Self::Error> {
loop {
match &mut self.read_state {
match this.read_state {
ReadState::ReadLength { buf, pos } => {
match self.inner.read(&mut buf[*pos .. *pos + 1]) {
Ok(0) => {
match this.inner.as_mut().poll_read(cx, &mut buf[*pos .. *pos + 1]) {
Poll::Ready(Ok(0)) => {
if *pos == 0 {
return Ok(Async::Ready(None));
return Poll::Ready(None);
} else {
return Err(io::ErrorKind::UnexpectedEof.into());
return Poll::Ready(Some(Err(io::ErrorKind::UnexpectedEof.into())));
}
}
Ok(n) => {
Poll::Ready(Ok(n)) => {
debug_assert_eq!(n, 1);
*pos += n;
}
Err(ref err) if err.kind() == io::ErrorKind::WouldBlock => {
return Ok(Async::NotReady);
}
Err(err) => {
return Err(err);
}
Poll::Ready(Err(err)) => return Poll::Ready(Some(Err(err))),
Poll::Pending => return Poll::Pending,
};
if (buf[*pos - 1] & 0x80) == 0 {
// MSB is not set, indicating the end of the length prefix.
let (len, _) = uvi::decode::u16(buf).map_err(|e| {
let (len, _) = unsigned_varint::decode::u16(buf)
.map_err(|e| {
log::debug!("invalid length prefix: {}", e);
io::Error::new(io::ErrorKind::InvalidData, "invalid length prefix")
})?;
if len >= 1 {
self.read_state = ReadState::ReadData { len, pos: 0 };
self.read_buffer.resize(len as usize, 0);
*this.read_state = ReadState::ReadData { len, pos: 0 };
this.read_buffer.resize(len as usize, 0);
} else {
debug_assert_eq!(len, 0);
self.read_state = ReadState::default();
return Ok(Async::Ready(Some(Bytes::new())));
*this.read_state = ReadState::default();
return Poll::Ready(Some(Ok(Bytes::new())));
}
} else if *pos == MAX_LEN_BYTES as usize {
// MSB signals more length bytes but we have already read the maximum.
// See the module documentation about the max frame len.
return Err(io::Error::new(
return Poll::Ready(Some(Err(io::Error::new(
io::ErrorKind::InvalidData,
"Maximum frame length exceeded"));
"Maximum frame length exceeded"))));
}
}
ReadState::ReadData { len, pos } => {
match self.inner.read(&mut self.read_buffer[*pos..]) {
Ok(0) => return Err(io::ErrorKind::UnexpectedEof.into()),
Ok(n) => *pos += n,
Err(err) =>
if err.kind() == io::ErrorKind::WouldBlock {
return Ok(Async::NotReady)
} else {
return Err(err)
}
match this.inner.as_mut().poll_read(cx, &mut this.read_buffer[*pos..]) {
Poll::Ready(Ok(0)) => return Poll::Ready(Some(Err(io::ErrorKind::UnexpectedEof.into()))),
Poll::Ready(Ok(n)) => *pos += n,
Poll::Pending => return Poll::Pending,
Poll::Ready(Err(err)) => return Poll::Ready(Some(Err(err))),
};
if *pos == *len as usize {
// Finished reading the frame.
let frame = self.read_buffer.split_off(0).freeze();
self.read_state = ReadState::default();
return Ok(Async::Ready(Some(frame)));
let frame = this.read_buffer.split_off(0).freeze();
*this.read_state = ReadState::default();
return Poll::Ready(Some(Ok(frame)));
}
}
}
@ -220,58 +206,87 @@ where
}
}
impl<R> Sink for LengthDelimited<R>
impl<R> Sink<Bytes> for LengthDelimited<R>
where
R: AsyncWrite,
{
type SinkItem = Bytes;
type SinkError = io::Error;
type Error = io::Error;
fn start_send(&mut self, msg: Self::SinkItem) -> StartSend<Self::SinkItem, Self::SinkError> {
fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
// Use the maximum frame length also as a (soft) upper limit
// for the entire write buffer. The actual (hard) limit is thus
// implied to be roughly 2 * MAX_FRAME_SIZE.
if self.write_buffer.len() >= MAX_FRAME_SIZE as usize {
self.poll_complete()?;
if self.write_buffer.len() >= MAX_FRAME_SIZE as usize {
return Ok(AsyncSink::NotReady(msg))
}
if self.as_mut().project().write_buffer.len() >= MAX_FRAME_SIZE as usize {
match self.as_mut().poll_write_buffer(cx) {
Poll::Ready(Ok(())) => {},
Poll::Ready(Err(err)) => return Poll::Ready(Err(err)),
Poll::Pending => return Poll::Pending,
}
let len = msg.len() as u16;
if len > MAX_FRAME_SIZE {
debug_assert!(self.as_mut().project().write_buffer.is_empty());
}
Poll::Ready(Ok(()))
}
fn start_send(self: Pin<&mut Self>, item: Bytes) -> Result<(), Self::Error> {
let this = self.project();
let len = match u16::try_from(item.len()) {
Ok(len) if len <= MAX_FRAME_SIZE => len,
_ => {
return Err(io::Error::new(
io::ErrorKind::InvalidData,
"Maximum frame size exceeded."))
}
};
let mut uvi_buf = uvi::encode::u16_buffer();
let uvi_len = uvi::encode::u16(len, &mut uvi_buf);
self.write_buffer.reserve(len as usize + uvi_len.len());
self.write_buffer.put(uvi_len);
self.write_buffer.put(msg);
let mut uvi_buf = unsigned_varint::encode::u16_buffer();
let uvi_len = unsigned_varint::encode::u16(len, &mut uvi_buf);
this.write_buffer.reserve(len as usize + uvi_len.len());
this.write_buffer.put(uvi_len);
this.write_buffer.put(item);
Ok(AsyncSink::Ready)
Ok(())
}
fn poll_complete(&mut self) -> Poll<(), Self::SinkError> {
fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
// Write all buffered frame data to the underlying I/O stream.
try_ready!(self.poll_write_buffer());
// Flush the underlying I/O stream.
try_ready!(self.inner.poll_flush());
return Ok(Async::Ready(()));
match LengthDelimited::poll_write_buffer(self.as_mut(), cx) {
Poll::Ready(Ok(())) => {},
Poll::Ready(Err(err)) => return Poll::Ready(Err(err)),
Poll::Pending => return Poll::Pending,
}
fn close(&mut self) -> Poll<(), Self::SinkError> {
try_ready!(self.poll_complete());
Ok(self.inner.shutdown()?)
let this = self.project();
debug_assert!(this.write_buffer.is_empty());
// Flush the underlying I/O stream.
this.inner.poll_flush(cx)
}
fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
// Write all buffered frame data to the underlying I/O stream.
match LengthDelimited::poll_write_buffer(self.as_mut(), cx) {
Poll::Ready(Ok(())) => {},
Poll::Ready(Err(err)) => return Poll::Ready(Err(err)),
Poll::Pending => return Poll::Pending,
}
let this = self.project();
debug_assert!(this.write_buffer.is_empty());
// Close the underlying I/O stream.
this.inner.poll_close(cx)
}
}
/// A `LengthDelimitedReader` implements a `Stream` of uvi-length-delimited
/// frames on an underlying I/O resource combined with direct `AsyncWrite` access.
#[pin_project::pin_project]
#[derive(Debug)]
pub struct LengthDelimitedReader<R> {
#[pin]
inner: LengthDelimited<R>
}
@ -284,58 +299,23 @@ impl<R> LengthDelimitedReader<R> {
/// # Panic
///
/// Will panic if called while there is data in the read or write buffer.
/// The read buffer is guaranteed to be empty whenever [`Stream::poll`] yields
/// a new `Message`. The write buffer is guaranteed to be empty whenever
/// [`LengthDelimited::poll_write_buffer`] yields [`Async::Ready`] or after
/// the [`Sink`] has been completely flushed via [`Sink::poll_complete`].
/// The read buffer is guaranteed to be empty whenever [`Stream::poll_next`]
/// yield a new `Message`. The write buffer is guaranteed to be empty whenever
/// [`LengthDelimited::poll_write_buffer`] yields [`Poll::Ready`] or after
/// the [`Sink`] has been completely flushed via [`Sink::poll_flush`].
pub fn into_inner(self) -> (R, BytesMut) {
self.inner.into_inner()
}
/// Returns a reference to the underlying I/O stream.
pub fn inner_ref(&self) -> &R {
self.inner.inner_ref()
}
/// Returns a mutable reference to the underlying I/O stream.
///
/// > **Note**: Care should be taken to not tamper with the underlying stream of data
/// > coming in, as it may corrupt the stream of frames.
pub fn inner_mut(&mut self) -> &mut R {
self.inner.inner_mut()
}
}
impl<R> Stream for LengthDelimitedReader<R>
where
R: AsyncRead
{
type Item = Bytes;
type Error = io::Error;
type Item = Result<Bytes, io::Error>;
fn poll(&mut self) -> Poll<Option<Self::Item>, Self::Error> {
self.inner.poll()
}
}
impl<R> io::Write for LengthDelimitedReader<R>
where
R: AsyncWrite
{
fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
while !self.inner.write_buffer.is_empty() {
if self.inner.poll_write_buffer()?.is_not_ready() {
return Err(io::ErrorKind::WouldBlock.into())
}
}
self.inner_mut().write(buf)
}
fn flush(&mut self) -> io::Result<()> {
match self.inner.poll_complete()? {
Async::Ready(()) => Ok(()),
Async::NotReady => Err(io::ErrorKind::WouldBlock.into())
}
fn poll_next(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Option<Self::Item>> {
self.project().inner.poll_next(cx)
}
}
@ -343,23 +323,62 @@ impl<R> AsyncWrite for LengthDelimitedReader<R>
where
R: AsyncWrite
{
fn shutdown(&mut self) -> Poll<(), io::Error> {
try_ready!(self.inner.poll_complete());
self.inner_mut().shutdown()
fn poll_write(self: Pin<&mut Self>, cx: &mut Context, buf: &[u8])
-> Poll<Result<usize, io::Error>>
{
// `this` here designates the `LengthDelimited`.
let mut this = self.project().inner;
// We need to flush any data previously written with the `LengthDelimited`.
match LengthDelimited::poll_write_buffer(this.as_mut(), cx) {
Poll::Ready(Ok(())) => {},
Poll::Ready(Err(err)) => return Poll::Ready(Err(err)),
Poll::Pending => return Poll::Pending,
}
debug_assert!(this.write_buffer.is_empty());
this.project().inner.poll_write(cx, buf)
}
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), io::Error>> {
self.project().inner.poll_flush(cx)
}
fn poll_close(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), io::Error>> {
self.project().inner.poll_close(cx)
}
fn poll_write_vectored(self: Pin<&mut Self>, cx: &mut Context, bufs: &[IoSlice])
-> Poll<Result<usize, io::Error>>
{
// `this` here designates the `LengthDelimited`.
let mut this = self.project().inner;
// We need to flush any data previously written with the `LengthDelimited`.
match LengthDelimited::poll_write_buffer(this.as_mut(), cx) {
Poll::Ready(Ok(())) => {},
Poll::Ready(Err(err)) => return Poll::Ready(Err(err)),
Poll::Pending => return Poll::Pending,
}
debug_assert!(this.write_buffer.is_empty());
this.project().inner.poll_write_vectored(cx, bufs)
}
}
#[cfg(test)]
mod tests {
use futures::{Future, Stream};
use crate::length_delimited::LengthDelimited;
use std::io::{Cursor, ErrorKind};
use async_std::net::{TcpListener, TcpStream};
use futures::{prelude::*, io::Cursor};
use quickcheck::*;
use std::io::ErrorKind;
#[test]
fn basic_read() {
let data = vec![6, 9, 8, 7, 6, 5, 4];
let framed = LengthDelimited::new(Cursor::new(data));
let recved = framed.collect().wait().unwrap();
let recved = futures::executor::block_on(framed.try_collect::<Vec<_>>()).unwrap();
assert_eq!(recved, vec![vec![9, 8, 7, 6, 5, 4]]);
}
@ -367,7 +386,7 @@ mod tests {
fn basic_read_two() {
let data = vec![6, 9, 8, 7, 6, 5, 4, 3, 9, 8, 7];
let framed = LengthDelimited::new(Cursor::new(data));
let recved = framed.collect().wait().unwrap();
let recved = futures::executor::block_on(framed.try_collect::<Vec<_>>()).unwrap();
assert_eq!(recved, vec![vec![9, 8, 7, 6, 5, 4], vec![9, 8, 7]]);
}
@ -378,13 +397,10 @@ mod tests {
let frame = (0..len).map(|n| (n & 0xff) as u8).collect::<Vec<_>>();
let mut data = vec![(len & 0x7f) as u8 | 0x80, (len >> 7) as u8];
data.extend(frame.clone().into_iter());
let framed = LengthDelimited::new(Cursor::new(data));
let recved = framed
.into_future()
.map(|(m, _)| m)
.map_err(|_| ())
.wait()
.unwrap();
let mut framed = LengthDelimited::new(Cursor::new(data));
let recved = futures::executor::block_on(async move {
framed.next().await
}).unwrap();
assert_eq!(recved.unwrap(), frame);
}
@ -392,12 +408,10 @@ mod tests {
fn packet_len_too_long() {
let mut data = vec![0x81, 0x81, 0x1];
data.extend((0..16513).map(|_| 0));
let framed = LengthDelimited::new(Cursor::new(data));
let recved = framed
.into_future()
.map(|(m, _)| m)
.map_err(|(err, _)| err)
.wait();
let mut framed = LengthDelimited::new(Cursor::new(data));
let recved = futures::executor::block_on(async move {
framed.next().await.unwrap()
});
if let Err(io_err) = recved {
assert_eq!(io_err.kind(), ErrorKind::InvalidData)
@ -410,7 +424,7 @@ mod tests {
fn empty_frames() {
let data = vec![0, 0, 6, 9, 8, 7, 6, 5, 4, 0, 3, 9, 8, 7];
let framed = LengthDelimited::new(Cursor::new(data));
let recved = framed.collect().wait().unwrap();
let recved = futures::executor::block_on(framed.try_collect::<Vec<_>>()).unwrap();
assert_eq!(
recved,
vec![
@ -427,7 +441,7 @@ mod tests {
fn unexpected_eof_in_len() {
let data = vec![0x89];
let framed = LengthDelimited::new(Cursor::new(data));
let recved = framed.collect().wait();
let recved = futures::executor::block_on(framed.try_collect::<Vec<_>>());
if let Err(io_err) = recved {
assert_eq!(io_err.kind(), ErrorKind::UnexpectedEof)
} else {
@ -439,7 +453,7 @@ mod tests {
fn unexpected_eof_in_data() {
let data = vec![5];
let framed = LengthDelimited::new(Cursor::new(data));
let recved = framed.collect().wait();
let recved = futures::executor::block_on(framed.try_collect::<Vec<_>>());
if let Err(io_err) = recved {
assert_eq!(io_err.kind(), ErrorKind::UnexpectedEof)
} else {
@ -451,12 +465,54 @@ mod tests {
fn unexpected_eof_in_data2() {
let data = vec![5, 9, 8, 7];
let framed = LengthDelimited::new(Cursor::new(data));
let recved = framed.collect().wait();
let recved = futures::executor::block_on(framed.try_collect::<Vec<_>>());
if let Err(io_err) = recved {
assert_eq!(io_err.kind(), ErrorKind::UnexpectedEof)
} else {
panic!()
}
}
#[test]
fn writing_reading() {
fn prop(frames: Vec<Vec<u8>>) -> TestResult {
async_std::task::block_on(async move {
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
let listener_addr = listener.local_addr().unwrap();
let expected_frames = frames.clone();
let server = async_std::task::spawn(async move {
let socket = listener.accept().await.unwrap().0;
let mut connec = rw_stream_sink::RwStreamSink::new(LengthDelimited::new(socket));
let mut buf = vec![0u8; 0];
for expected in expected_frames {
if expected.is_empty() {
continue;
}
if buf.len() < expected.len() {
buf.resize(expected.len(), 0);
}
let n = connec.read(&mut buf).await.unwrap();
assert_eq!(&buf[..n], &expected[..]);
}
});
let client = async_std::task::spawn(async move {
let socket = TcpStream::connect(&listener_addr).await.unwrap();
let mut connec = LengthDelimited::new(socket);
for frame in frames {
connec.send(From::from(frame)).await.unwrap();
}
});
server.await;
client.await;
});
TestResult::passed()
}
quickcheck(prop as fn(_) -> _)
}
}

View File

@ -77,26 +77,19 @@
//!
//! ```no_run
//! # fn main() {
//! use bytes::Bytes;
//! use async_std::net::TcpStream;
//! use multistream_select::{dialer_select_proto, Version};
//! use futures::{Future, Sink, Stream};
//! use tokio_tcp::TcpStream;
//! use tokio::runtime::current_thread::Runtime;
//! use futures::prelude::*;
//!
//! #[derive(Debug, Copy, Clone)]
//! enum MyProto { Echo, Hello }
//! async_std::task::block_on(async move {
//! let socket = TcpStream::connect("127.0.0.1:10333").await.unwrap();
//!
//! let client = TcpStream::connect(&"127.0.0.1:10333".parse().unwrap())
//! .from_err()
//! .and_then(move |io| {
//! let protos = vec![b"/echo/1.0.0", b"/echo/2.5.0"];
//! dialer_select_proto(io, protos, Version::V1)
//! })
//! .map(|(protocol, _io)| protocol);
//! let (protocol, _io) = dialer_select_proto(socket, protos, Version::V1).await.unwrap();
//!
//! let mut rt = Runtime::new().unwrap();
//! let protocol = rt.block_on(client).expect("failed to find a protocol");
//! println!("Negotiated protocol: {:?}", protocol);
//! // You can now use `_io` to communicate with the remote.
//! });
//! # }
//! ```
//!

View File

@ -21,13 +21,12 @@
//! Protocol negotiation strategies for the peer acting as the listener
//! in a multistream-select protocol negotiation.
use futures::prelude::*;
use crate::protocol::{Protocol, ProtocolError, MessageIO, Message, Version};
use log::{debug, warn};
use smallvec::SmallVec;
use std::{io, iter::FromIterator, mem, convert::TryFrom};
use tokio_io::{AsyncRead, AsyncWrite};
use crate::{Negotiated, NegotiationError};
use crate::protocol::{Protocol, ProtocolError, MessageIO, Message, Version};
use futures::prelude::*;
use smallvec::SmallVec;
use std::{convert::TryFrom as _, io, iter::FromIterator, mem, pin::Pin, task::{Context, Poll}};
/// Returns a `Future` that negotiates a protocol on the given I/O stream
/// for a peer acting as the _listener_ (or _responder_).
@ -49,7 +48,7 @@ where
match Protocol::try_from(n.as_ref()) {
Ok(p) => Some((n, p)),
Err(e) => {
warn!("Listener: Ignoring invalid protocol: {} due to {}",
log::warn!("Listener: Ignoring invalid protocol: {} due to {}",
String::from_utf8_lossy(n.as_ref()), e);
None
}
@ -64,6 +63,7 @@ where
/// The `Future` returned by [`listener_select_proto`] that performs a
/// multistream-select protocol negotiation on an underlying I/O stream.
#[pin_project::pin_project]
pub struct ListenerSelectFuture<R, N>
where
R: AsyncRead + AsyncWrite,
@ -94,64 +94,80 @@ where
impl<R, N> Future for ListenerSelectFuture<R, N>
where
R: AsyncRead + AsyncWrite,
// The Unpin bound here is required because we produce a `Negotiated<R>` as the output.
// It also makes the implementation considerably easier to write.
R: AsyncRead + AsyncWrite + Unpin,
N: AsRef<[u8]> + Clone
{
type Item = (N, Negotiated<R>);
type Error = NegotiationError;
type Output = Result<(N, Negotiated<R>), NegotiationError>;
fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
let this = self.project();
fn poll(&mut self) -> Poll<Self::Item, Self::Error> {
loop {
match mem::replace(&mut self.state, State::Done) {
match mem::replace(this.state, State::Done) {
State::RecvHeader { mut io } => {
match io.poll()? {
Async::Ready(Some(Message::Header(version))) => {
self.state = State::SendHeader { io, version }
match io.poll_next_unpin(cx) {
Poll::Ready(Some(Ok(Message::Header(version)))) => {
*this.state = State::SendHeader { io, version }
}
Async::Ready(Some(_)) => {
return Err(ProtocolError::InvalidMessage.into())
}
Async::Ready(None) =>
return Err(NegotiationError::from(
Poll::Ready(Some(Ok(_))) => {
return Poll::Ready(Err(ProtocolError::InvalidMessage.into()))
},
Poll::Ready(Some(Err(err))) => return Poll::Ready(Err(From::from(err))),
Poll::Ready(None) =>
return Poll::Ready(Err(NegotiationError::from(
ProtocolError::IoError(
io::ErrorKind::UnexpectedEof.into()))),
Async::NotReady => {
self.state = State::RecvHeader { io };
return Ok(Async::NotReady)
io::ErrorKind::UnexpectedEof.into())))),
Poll::Pending => {
*this.state = State::RecvHeader { io };
return Poll::Pending
}
}
}
State::SendHeader { mut io, version } => {
if io.start_send(Message::Header(version))?.is_not_ready() {
return Ok(Async::NotReady)
match Pin::new(&mut io).poll_ready(cx) {
Poll::Pending => {
*this.state = State::SendHeader { io, version };
return Poll::Pending
},
Poll::Ready(Ok(())) => {},
Poll::Ready(Err(err)) => return Poll::Ready(Err(From::from(err))),
}
self.state = match version {
if let Err(err) = Pin::new(&mut io).start_send(Message::Header(version)) {
return Poll::Ready(Err(From::from(err)));
}
*this.state = match version {
Version::V1 => State::Flush { io },
Version::V1Lazy => State::RecvMessage { io },
}
}
State::RecvMessage { mut io } => {
let msg = match io.poll() {
Ok(Async::Ready(Some(msg))) => msg,
Ok(Async::Ready(None)) =>
return Err(NegotiationError::from(
let msg = match Pin::new(&mut io).poll_next(cx) {
Poll::Ready(Some(Ok(msg))) => msg,
Poll::Ready(None) =>
return Poll::Ready(Err(NegotiationError::from(
ProtocolError::IoError(
io::ErrorKind::UnexpectedEof.into()))),
Ok(Async::NotReady) => {
self.state = State::RecvMessage { io };
return Ok(Async::NotReady)
io::ErrorKind::UnexpectedEof.into())))),
Poll::Pending => {
*this.state = State::RecvMessage { io };
return Poll::Pending;
}
Err(e) => return Err(e.into())
Poll::Ready(Some(Err(err))) => return Poll::Ready(Err(From::from(err))),
};
match msg {
Message::ListProtocols => {
let supported = self.protocols.iter().map(|(_,p)| p).cloned().collect();
let supported = this.protocols.iter().map(|(_,p)| p).cloned().collect();
let message = Message::Protocols(supported);
self.state = State::SendMessage { io, message, protocol: None }
*this.state = State::SendMessage { io, message, protocol: None }
}
Message::Protocol(p) => {
let protocol = self.protocols.iter().find_map(|(name, proto)| {
let protocol = this.protocols.iter().find_map(|(name, proto)| {
if &p == proto {
Some(name.clone())
} else {
@ -160,45 +176,60 @@ where
});
let message = if protocol.is_some() {
debug!("Listener: confirming protocol: {}", p);
log::debug!("Listener: confirming protocol: {}", p);
Message::Protocol(p.clone())
} else {
debug!("Listener: rejecting protocol: {}",
log::debug!("Listener: rejecting protocol: {}",
String::from_utf8_lossy(p.as_ref()));
Message::NotAvailable
};
self.state = State::SendMessage { io, message, protocol };
*this.state = State::SendMessage { io, message, protocol };
}
_ => return Err(ProtocolError::InvalidMessage.into())
_ => return Poll::Ready(Err(ProtocolError::InvalidMessage.into()))
}
}
State::SendMessage { mut io, message, protocol } => {
if let AsyncSink::NotReady(message) = io.start_send(message)? {
self.state = State::SendMessage { io, message, protocol };
return Ok(Async::NotReady)
};
match Pin::new(&mut io).poll_ready(cx) {
Poll::Pending => {
*this.state = State::SendMessage { io, message, protocol };
return Poll::Pending
},
Poll::Ready(Ok(())) => {},
Poll::Ready(Err(err)) => return Poll::Ready(Err(From::from(err))),
}
if let Err(err) = Pin::new(&mut io).start_send(message) {
return Poll::Ready(Err(From::from(err)));
}
// If a protocol has been selected, finish negotiation.
// Otherwise flush the sink and expect to receive another
// message.
self.state = match protocol {
*this.state = match protocol {
Some(protocol) => {
debug!("Listener: sent confirmed protocol: {}",
log::debug!("Listener: sent confirmed protocol: {}",
String::from_utf8_lossy(protocol.as_ref()));
let (io, remaining) = io.into_inner();
let io = Negotiated::completed(io, remaining);
return Ok(Async::Ready((protocol, io)))
return Poll::Ready(Ok((protocol, io)));
}
None => State::Flush { io }
};
}
State::Flush { mut io } => {
if io.poll_complete()?.is_not_ready() {
self.state = State::Flush { io };
return Ok(Async::NotReady)
match Pin::new(&mut io).poll_flush(cx) {
Poll::Pending => {
*this.state = State::Flush { io };
return Poll::Pending
},
Poll::Ready(Ok(())) => *this.state = State::RecvMessage { io },
Poll::Ready(Err(err)) => return Poll::Ready(Err(From::from(err))),
}
self.state = State::RecvMessage { io }
}
State::Done => panic!("State::poll called after completion")
}
}

View File

@ -18,12 +18,12 @@
// FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
// DEALINGS IN THE SOFTWARE.
use bytes::{BytesMut, Buf};
use crate::protocol::{Protocol, MessageReader, Message, Version, ProtocolError};
use futures::{prelude::*, Async, try_ready};
use log::debug;
use tokio_io::{AsyncRead, AsyncWrite};
use std::{mem, io, fmt, error::Error};
use bytes::{BytesMut, Buf};
use futures::{prelude::*, io::{IoSlice, IoSliceMut}, ready};
use pin_project::{pin_project, project};
use std::{error::Error, fmt, io, mem, pin::Pin, task::{Context, Poll}};
/// An I/O stream that has settled on an (application-layer) protocol to use.
///
@ -36,28 +36,40 @@ use std::{mem, io, fmt, error::Error};
///
/// Reading from a `Negotiated` I/O stream that still has pending negotiation
/// protocol data to send implicitly triggers flushing of all yet unsent data.
#[pin_project]
#[derive(Debug)]
pub struct Negotiated<TInner> {
#[pin]
state: State<TInner>
}
/// A `Future` that waits on the completion of protocol negotiation.
#[derive(Debug)]
pub struct NegotiatedComplete<TInner> {
inner: Option<Negotiated<TInner>>
inner: Option<Negotiated<TInner>>,
}
impl<TInner: AsyncRead + AsyncWrite> Future for NegotiatedComplete<TInner> {
type Item = Negotiated<TInner>;
type Error = NegotiationError;
impl<TInner> Future for NegotiatedComplete<TInner>
where
// `Unpin` is required not because of implementation details but because we produce the
// `Negotiated` as the output of the future.
TInner: AsyncRead + AsyncWrite + Unpin,
{
type Output = Result<Negotiated<TInner>, NegotiationError>;
fn poll(&mut self) -> Poll<Self::Item, Self::Error> {
fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
let mut io = self.inner.take().expect("NegotiatedFuture called after completion.");
if io.poll()?.is_not_ready() {
match Negotiated::poll(Pin::new(&mut io), cx) {
Poll::Pending => {
self.inner = Some(io);
return Ok(Async::NotReady)
return Poll::Pending
},
Poll::Ready(Ok(())) => Poll::Ready(Ok(io)),
Poll::Ready(Err(err)) => {
self.inner = Some(io);
return Poll::Ready(Err(err));
}
}
return Ok(Async::Ready(io))
}
}
@ -75,66 +87,67 @@ impl<TInner> Negotiated<TInner> {
}
/// Polls the `Negotiated` for completion.
fn poll(&mut self) -> Poll<(), NegotiationError>
#[project]
fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), NegotiationError>>
where
TInner: AsyncRead + AsyncWrite
TInner: AsyncRead + AsyncWrite + Unpin
{
// Flush any pending negotiation data.
match self.poll_flush() {
Ok(Async::Ready(())) => {},
Ok(Async::NotReady) => return Ok(Async::NotReady),
Err(e) => {
match self.as_mut().poll_flush(cx) {
Poll::Ready(Ok(())) => {},
Poll::Pending => return Poll::Pending,
Poll::Ready(Err(e)) => {
// If the remote closed the stream, it is important to still
// continue reading the data that was sent, if any.
if e.kind() != io::ErrorKind::WriteZero {
return Err(e.into())
return Poll::Ready(Err(e.into()))
}
}
}
if let State::Completed { remaining, .. } = &mut self.state {
let _ = remaining.split_to(remaining.len()); // Drop remaining data flushed above.
return Ok(Async::Ready(()))
let mut this = self.project();
#[project]
match this.state.as_mut().project() {
State::Completed { remaining, .. } => {
debug_assert!(remaining.is_empty());
return Poll::Ready(Ok(()))
}
_ => {}
}
// Read outstanding protocol negotiation messages.
loop {
match mem::replace(&mut self.state, State::Invalid) {
match mem::replace(&mut *this.state, State::Invalid) {
State::Expecting { mut io, protocol, version } => {
let msg = match io.poll() {
Ok(Async::Ready(Some(msg))) => msg,
Ok(Async::NotReady) => {
self.state = State::Expecting { io, protocol, version };
return Ok(Async::NotReady)
}
Ok(Async::Ready(None)) => {
self.state = State::Expecting { io, protocol, version };
return Err(ProtocolError::IoError(
io::ErrorKind::UnexpectedEof.into()).into())
}
Err(err) => {
self.state = State::Expecting { io, protocol, version };
return Err(err.into())
let msg = match Pin::new(&mut io).poll_next(cx)? {
Poll::Ready(Some(msg)) => msg,
Poll::Pending => {
*this.state = State::Expecting { io, protocol, version };
return Poll::Pending
},
Poll::Ready(None) => {
return Poll::Ready(Err(ProtocolError::IoError(
io::ErrorKind::UnexpectedEof.into()).into()));
}
};
if let Message::Header(v) = &msg {
if v == &version {
self.state = State::Expecting { io, protocol, version };
if *v == version {
continue
}
}
if let Message::Protocol(p) = &msg {
if p.as_ref() == protocol.as_ref() {
debug!("Negotiated: Received confirmation for protocol: {}", p);
log::debug!("Negotiated: Received confirmation for protocol: {}", p);
let (io, remaining) = io.into_inner();
self.state = State::Completed { io, remaining };
return Ok(Async::Ready(()))
*this.state = State::Completed { io, remaining };
return Poll::Ready(Ok(()));
}
}
return Err(NegotiationError::Failed)
return Poll::Ready(Err(NegotiationError::Failed));
}
_ => panic!("Negotiated: Invalid state")
@ -142,7 +155,7 @@ impl<TInner> Negotiated<TInner> {
}
}
/// Returns a `NegotiatedComplete` future that waits for protocol
/// Returns a [`NegotiatedComplete`] future that waits for protocol
/// negotiation to complete.
pub fn complete(self) -> NegotiatedComplete<TInner> {
NegotiatedComplete { inner: Some(self) }
@ -150,12 +163,14 @@ impl<TInner> Negotiated<TInner> {
}
/// The states of a `Negotiated` I/O stream.
#[pin_project]
#[derive(Debug)]
enum State<R> {
/// In this state, a `Negotiated` is still expecting to
/// receive confirmation of the protocol it as settled on.
Expecting {
/// The underlying I/O stream.
#[pin]
io: MessageReader<R>,
/// The expected protocol (i.e. name and version).
protocol: Protocol,
@ -167,113 +182,157 @@ enum State<R> {
/// only be pending the sending of the final acknowledgement,
/// which is prepended to / combined with the next write for
/// efficiency.
Completed { io: R, remaining: BytesMut },
Completed { #[pin] io: R, remaining: BytesMut },
/// Temporary state while moving the `io` resource from
/// `Expecting` to `Completed`.
Invalid,
}
impl<R> io::Read for Negotiated<R>
impl<TInner> AsyncRead for Negotiated<TInner>
where
R: AsyncRead + AsyncWrite
TInner: AsyncRead + AsyncWrite + Unpin
{
#[project]
fn poll_read(mut self: Pin<&mut Self>, cx: &mut Context, buf: &mut [u8])
-> Poll<Result<usize, io::Error>>
{
fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
loop {
if let State::Completed { io, remaining } = &mut self.state {
#[project]
match self.as_mut().project().state.project() {
State::Completed { io, remaining } => {
// If protocol negotiation is complete and there is no
// remaining data to be flushed, commence with reading.
if remaining.is_empty() {
return io.read(buf)
return io.poll_read(cx, buf)
}
},
_ => {}
}
// Poll the `Negotiated`, driving protocol negotiation to completion,
// including flushing of any remaining data.
let result = self.poll();
// There is still remaining data to be sent before data relating
// to the negotiated protocol can be read.
if let Ok(Async::NotReady) = result {
return Err(io::ErrorKind::WouldBlock.into())
}
if let Err(err) = result {
return Err(err.into())
}
match self.as_mut().poll(cx) {
Poll::Ready(Ok(())) => {},
Poll::Pending => return Poll::Pending,
Poll::Ready(Err(err)) => return Poll::Ready(Err(From::from(err))),
}
}
}
impl<TInner> AsyncRead for Negotiated<TInner>
where
TInner: AsyncRead + AsyncWrite
{
unsafe fn prepare_uninitialized_buffer(&self, buf: &mut [u8]) -> bool {
// TODO: implement once method is stabilized in the futures crate
/*unsafe fn initializer(&self) -> Initializer {
match &self.state {
State::Completed { io, .. } =>
io.prepare_uninitialized_buffer(buf),
State::Expecting { io, .. } =>
io.inner_ref().prepare_uninitialized_buffer(buf),
State::Invalid => panic!("Negotiated: Invalid state")
}
}
State::Completed { io, .. } => io.initializer(),
State::Expecting { io, .. } => io.inner_ref().initializer(),
State::Invalid => panic!("Negotiated: Invalid state"),
}
}*/
impl<TInner> io::Write for Negotiated<TInner>
where
TInner: AsyncWrite
#[project]
fn poll_read_vectored(mut self: Pin<&mut Self>, cx: &mut Context, bufs: &mut [IoSliceMut])
-> Poll<Result<usize, io::Error>>
{
fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
match &mut self.state {
State::Completed { io, ref mut remaining } => {
while !remaining.is_empty() {
let n = io.write(&remaining)?;
if n == 0 {
return Err(io::ErrorKind::WriteZero.into())
loop {
#[project]
match self.as_mut().project().state.project() {
State::Completed { io, remaining } => {
// If protocol negotiation is complete and there is no
// remaining data to be flushed, commence with reading.
if remaining.is_empty() {
return io.poll_read_vectored(cx, bufs)
}
remaining.advance(n);
}
io.write(buf)
},
State::Expecting { io, .. } => io.write(buf),
State::Invalid => panic!("Negotiated: Invalid state")
}
_ => {}
}
fn flush(&mut self) -> io::Result<()> {
match &mut self.state {
State::Completed { io, ref mut remaining } => {
while !remaining.is_empty() {
let n = io.write(remaining)?;
if n == 0 {
return Err(io::Error::new(
io::ErrorKind::WriteZero,
"Failed to write remaining buffer."))
// Poll the `Negotiated`, driving protocol negotiation to completion,
// including flushing of any remaining data.
match self.as_mut().poll(cx) {
Poll::Ready(Ok(())) => {},
Poll::Pending => return Poll::Pending,
Poll::Ready(Err(err)) => return Poll::Ready(Err(From::from(err))),
}
remaining.advance(n);
}
io.flush()
},
State::Expecting { io, .. } => io.flush(),
State::Invalid => panic!("Negotiated: Invalid state")
}
}
}
impl<TInner> AsyncWrite for Negotiated<TInner>
where
TInner: AsyncWrite + AsyncRead
TInner: AsyncWrite + AsyncRead + Unpin
{
fn shutdown(&mut self) -> Poll<(), io::Error> {
#[project]
fn poll_write(self: Pin<&mut Self>, cx: &mut Context, buf: &[u8]) -> Poll<Result<usize, io::Error>> {
#[project]
match self.project().state.project() {
State::Completed { mut io, remaining } => {
while !remaining.is_empty() {
let n = ready!(io.as_mut().poll_write(cx, &remaining)?);
if n == 0 {
return Poll::Ready(Err(io::ErrorKind::WriteZero.into()))
}
remaining.advance(n);
}
io.poll_write(cx, buf)
},
State::Expecting { io, .. } => io.poll_write(cx, buf),
State::Invalid => panic!("Negotiated: Invalid state"),
}
}
#[project]
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), io::Error>> {
#[project]
match self.project().state.project() {
State::Completed { mut io, remaining } => {
while !remaining.is_empty() {
let n = ready!(io.as_mut().poll_write(cx, &remaining)?);
if n == 0 {
return Poll::Ready(Err(io::ErrorKind::WriteZero.into()))
}
remaining.advance(n);
}
io.poll_flush(cx)
},
State::Expecting { io, .. } => io.poll_flush(cx),
State::Invalid => panic!("Negotiated: Invalid state"),
}
}
#[project]
fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), io::Error>> {
// Ensure all data has been flushed and expected negotiation messages
// have been received.
try_ready!(self.poll().map_err(Into::<io::Error>::into));
ready!(self.as_mut().poll(cx).map_err(Into::<io::Error>::into)?);
ready!(self.as_mut().poll_flush(cx).map_err(Into::<io::Error>::into)?);
// Continue with the shutdown of the underlying I/O stream.
match &mut self.state {
State::Completed { io, .. } => io.shutdown(),
State::Expecting { io, .. } => io.shutdown(),
State::Invalid => panic!("Negotiated: Invalid state")
#[project]
match self.project().state.project() {
State::Completed { io, .. } => io.poll_close(cx),
State::Expecting { io, .. } => io.poll_close(cx),
State::Invalid => panic!("Negotiated: Invalid state"),
}
}
#[project]
fn poll_write_vectored(self: Pin<&mut Self>, cx: &mut Context, bufs: &[IoSlice])
-> Poll<Result<usize, io::Error>>
{
#[project]
match self.project().state.project() {
State::Completed { mut io, remaining } => {
while !remaining.is_empty() {
let n = ready!(io.as_mut().poll_write(cx, &remaining)?);
if n == 0 {
return Poll::Ready(Err(io::ErrorKind::WriteZero.into()))
}
remaining.advance(n);
}
io.poll_write_vectored(cx, bufs)
},
State::Expecting { io, .. } => io.poll_write_vectored(cx, bufs),
State::Invalid => panic!("Negotiated: Invalid state"),
}
}
}
@ -300,12 +359,12 @@ impl From<io::Error> for NegotiationError {
}
}
impl Into<io::Error> for NegotiationError {
fn into(self) -> io::Error {
if let NegotiationError::ProtocolError(e) = self {
impl From<NegotiationError> for io::Error {
fn from(err: NegotiationError) -> io::Error {
if let NegotiationError::ProtocolError(e) = err {
return e.into()
}
io::Error::new(io::ErrorKind::Other, self)
io::Error::new(io::ErrorKind::Other, err)
}
}
@ -333,27 +392,33 @@ impl fmt::Display for NegotiationError {
mod tests {
use super::*;
use quickcheck::*;
use std::io::Write;
use std::{io::Write, task::Poll};
/// An I/O resource with a fixed write capacity (total and per write op).
struct Capped { buf: Vec<u8>, step: usize }
impl io::Write for Capped {
fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
if self.buf.len() + buf.len() > self.buf.capacity() {
return Err(io::ErrorKind::WriteZero.into())
}
self.buf.write(&buf[.. usize::min(self.step, buf.len())])
}
fn flush(&mut self) -> io::Result<()> {
Ok(())
impl AsyncRead for Capped {
fn poll_read(self: Pin<&mut Self>, _: &mut Context, _: &mut [u8]) -> Poll<Result<usize, io::Error>> {
unreachable!()
}
}
impl AsyncWrite for Capped {
fn shutdown(&mut self) -> Poll<(), io::Error> {
Ok(().into())
fn poll_write(mut self: Pin<&mut Self>, _: &mut Context, buf: &[u8]) -> Poll<Result<usize, io::Error>> {
if self.buf.len() + buf.len() > self.buf.capacity() {
return Poll::Ready(Err(io::ErrorKind::WriteZero.into()))
}
let len = usize::min(self.step, buf.len());
let n = Write::write(&mut self.buf, &buf[.. len]).unwrap();
Poll::Ready(Ok(n))
}
fn poll_flush(self: Pin<&mut Self>, _: &mut Context) -> Poll<Result<(), io::Error>> {
Poll::Ready(Ok(()))
}
fn poll_close(self: Pin<&mut Self>, _: &mut Context) -> Poll<Result<(), io::Error>> {
Poll::Ready(Ok(()))
}
}
@ -369,7 +434,7 @@ mod tests {
loop {
// Write until `new` has been fully written or the capped buffer runs
// over capacity and yields WriteZero.
match io.write(&new[written..]) {
match future::poll_fn(|cx| Pin::new(&mut io).poll_write(cx, &new[written..])).now_or_never().unwrap() {
Ok(n) =>
if let State::Completed { remaining, .. } = &io.state {
assert!(remaining.is_empty());
@ -388,7 +453,7 @@ mod tests {
return TestResult::failed()
}
}
Err(e) => panic!("Unexpected error: {:?}", e)
Err(e) => panic!("Unexpected error: {:?}", e),
}
}
}

View File

@ -25,12 +25,11 @@
//! `Stream` and `Sink` implementations of `MessageIO` and
//! `MessageReader`.
use bytes::{Bytes, BytesMut, BufMut};
use crate::length_delimited::{LengthDelimited, LengthDelimitedReader};
use futures::{prelude::*, try_ready};
use log::trace;
use std::{io, fmt, error::Error, convert::TryFrom};
use tokio_io::{AsyncRead, AsyncWrite};
use bytes::{Bytes, BytesMut, BufMut};
use futures::{prelude::*, io::IoSlice, ready};
use std::{convert::TryFrom, io, fmt, error::Error, pin::Pin, task::{Context, Poll}};
use unsigned_varint as uvi;
/// The maximum number of supported protocols that can be processed.
@ -264,7 +263,9 @@ impl Message {
}
/// A `MessageIO` implements a [`Stream`] and [`Sink`] of [`Message`]s.
#[pin_project::pin_project]
pub struct MessageIO<R> {
#[pin]
inner: LengthDelimited<R>,
}
@ -277,8 +278,8 @@ impl<R> MessageIO<R> {
Self { inner: LengthDelimited::new(inner) }
}
/// Converts the `MessageIO` into a `MessageReader`, dropping the
/// `Message`-oriented `Sink` in favour of direct `AsyncWrite` access
/// Converts the [`MessageIO`] into a [`MessageReader`], dropping the
/// [`Message`]-oriented `Sink` in favour of direct `AsyncWrite` access
/// to the underlying I/O stream.
///
/// This is typically done if further negotiation messages are expected to be
@ -288,7 +289,7 @@ impl<R> MessageIO<R> {
MessageReader { inner: self.inner.into_reader() }
}
/// Drops the `MessageIO` resource, yielding the underlying I/O stream
/// Drops the [`MessageIO`] resource, yielding the underlying I/O stream
/// together with the remaining write buffer containing the protocol
/// negotiation frame data that has not yet been written to the I/O stream.
///
@ -309,28 +310,28 @@ impl<R> MessageIO<R> {
}
}
impl<R> Sink for MessageIO<R>
impl<R> Sink<Message> for MessageIO<R>
where
R: AsyncWrite,
{
type SinkItem = Message;
type SinkError = ProtocolError;
type Error = ProtocolError;
fn start_send(&mut self, msg: Self::SinkItem) -> StartSend<Self::SinkItem, Self::SinkError> {
fn poll_ready(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
self.project().inner.poll_ready(cx).map_err(From::from)
}
fn start_send(self: Pin<&mut Self>, item: Message) -> Result<(), Self::Error> {
let mut buf = BytesMut::new();
msg.encode(&mut buf)?;
match self.inner.start_send(buf.freeze())? {
AsyncSink::NotReady(_) => Ok(AsyncSink::NotReady(msg)),
AsyncSink::Ready => Ok(AsyncSink::Ready),
}
item.encode(&mut buf)?;
self.project().inner.start_send(buf.freeze()).map_err(From::from)
}
fn poll_complete(&mut self) -> Poll<(), Self::SinkError> {
Ok(self.inner.poll_complete()?)
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
self.project().inner.poll_flush(cx).map_err(From::from)
}
fn close(&mut self) -> Poll<(), Self::SinkError> {
Ok(self.inner.close()?)
fn poll_close(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
self.project().inner.poll_close(cx).map_err(From::from)
}
}
@ -338,18 +339,24 @@ impl<R> Stream for MessageIO<R>
where
R: AsyncRead
{
type Item = Message;
type Error = ProtocolError;
type Item = Result<Message, ProtocolError>;
fn poll(&mut self) -> Poll<Option<Self::Item>, Self::Error> {
poll_stream(&mut self.inner)
fn poll_next(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Option<Self::Item>> {
match poll_stream(self.project().inner, cx) {
Poll::Pending => Poll::Pending,
Poll::Ready(None) => Poll::Ready(None),
Poll::Ready(Some(Ok(m))) => Poll::Ready(Some(Ok(m))),
Poll::Ready(Some(Err(err))) => Poll::Ready(Some(Err(From::from(err)))),
}
}
}
/// A `MessageReader` implements a `Stream` of `Message`s on an underlying
/// I/O resource combined with direct `AsyncWrite` access.
#[pin_project::pin_project]
#[derive(Debug)]
pub struct MessageReader<R> {
#[pin]
inner: LengthDelimitedReader<R>
}
@ -373,35 +380,16 @@ impl<R> MessageReader<R> {
pub fn into_inner(self) -> (R, BytesMut) {
self.inner.into_inner()
}
/// Returns a reference to the underlying I/O stream.
pub fn inner_ref(&self) -> &R {
self.inner.inner_ref()
}
}
impl<R> Stream for MessageReader<R>
where
R: AsyncRead
{
type Item = Message;
type Error = ProtocolError;
type Item = Result<Message, ProtocolError>;
fn poll(&mut self) -> Poll<Option<Self::Item>, Self::Error> {
poll_stream(&mut self.inner)
}
}
impl<R> io::Write for MessageReader<R>
where
R: AsyncWrite
{
fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
self.inner.write(buf)
}
fn flush(&mut self) -> io::Result<()> {
self.inner.flush()
fn poll_next(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Option<Self::Item>> {
poll_stream(self.project().inner, cx)
}
}
@ -409,24 +397,39 @@ impl<TInner> AsyncWrite for MessageReader<TInner>
where
TInner: AsyncWrite
{
fn shutdown(&mut self) -> Poll<(), io::Error> {
self.inner.shutdown()
fn poll_write(self: Pin<&mut Self>, cx: &mut Context, buf: &[u8]) -> Poll<Result<usize, io::Error>> {
self.project().inner.poll_write(cx, buf)
}
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), io::Error>> {
self.project().inner.poll_flush(cx)
}
fn poll_close(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), io::Error>> {
self.project().inner.poll_close(cx)
}
fn poll_write_vectored(self: Pin<&mut Self>, cx: &mut Context, bufs: &[IoSlice]) -> Poll<Result<usize, io::Error>> {
self.project().inner.poll_write_vectored(cx, bufs)
}
}
fn poll_stream<S>(stream: &mut S) -> Poll<Option<Message>, ProtocolError>
fn poll_stream<S>(stream: Pin<&mut S>, cx: &mut Context) -> Poll<Option<Result<Message, ProtocolError>>>
where
S: Stream<Item = Bytes, Error = io::Error>,
S: Stream<Item = Result<Bytes, io::Error>>,
{
let msg = if let Some(msg) = try_ready!(stream.poll()) {
Message::decode(msg)?
let msg = if let Some(msg) = ready!(stream.poll_next(cx)?) {
match Message::decode(msg) {
Ok(m) => m,
Err(err) => return Poll::Ready(Some(Err(err))),
}
} else {
return Ok(Async::Ready(None))
return Poll::Ready(None)
};
trace!("Received message: {:?}", msg);
log::trace!("Received message: {:?}", msg);
Ok(Async::Ready(Some(msg)))
Poll::Ready(Some(Ok(msg)))
}
/// A protocol error.

View File

@ -25,164 +25,156 @@
use crate::{Version, NegotiationError};
use crate::dialer_select::{dialer_select_proto_parallel, dialer_select_proto_serial};
use crate::{dialer_select_proto, listener_select_proto};
use async_std::net::{TcpListener, TcpStream};
use futures::prelude::*;
use tokio::runtime::current_thread::Runtime;
use tokio_tcp::{TcpListener, TcpStream};
use tokio_io::io as nio;
#[test]
fn select_proto_basic() {
fn run(version: Version) {
let listener = TcpListener::bind(&"127.0.0.1:0".parse().unwrap()).unwrap();
async fn run(version: Version) {
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
let listener_addr = listener.local_addr().unwrap();
let server = listener
.incoming()
.into_future()
.map(|s| s.0.unwrap())
.map_err(|(e, _)| e.into())
.and_then(move |connec| {
let server = async_std::task::spawn(async move {
let connec = listener.accept().await.unwrap().0;
let protos = vec![b"/proto1", b"/proto2"];
listener_select_proto(connec, protos)
})
.and_then(|(proto, io)| {
nio::write_all(io, b"pong").from_err().map(move |_| proto)
let (proto, mut io) = listener_select_proto(connec, protos).await.unwrap();
assert_eq!(proto, b"/proto2");
let mut out = vec![0; 32];
let n = io.read(&mut out).await.unwrap();
out.truncate(n);
assert_eq!(out, b"ping");
io.write_all(b"pong").await.unwrap();
io.flush().await.unwrap();
});
let client = TcpStream::connect(&listener_addr)
.from_err()
.and_then(move |connec| {
let client = async_std::task::spawn(async move {
let connec = TcpStream::connect(&listener_addr).await.unwrap();
let protos = vec![b"/proto3", b"/proto2"];
dialer_select_proto(connec, protos, version)
})
.and_then(|(proto, io)| {
nio::write_all(io, b"ping").from_err().map(move |(io, _)| (proto, io))
})
.and_then(|(proto, io)| {
nio::read_exact(io, [0; 4]).from_err().map(move |(_, msg)| {
assert_eq!(&msg, b"pong");
proto
})
let (proto, mut io) = dialer_select_proto(connec, protos.into_iter(), version)
.await.unwrap();
assert_eq!(proto, b"/proto2");
io.write_all(b"ping").await.unwrap();
io.flush().await.unwrap();
let mut out = vec![0; 32];
let n = io.read(&mut out).await.unwrap();
out.truncate(n);
assert_eq!(out, b"pong");
});
let mut rt = Runtime::new().unwrap();
let (dialer_chosen, listener_chosen) =
rt.block_on(client.join(server)).unwrap();
assert_eq!(dialer_chosen, b"/proto2");
assert_eq!(listener_chosen, b"/proto2");
server.await;
client.await;
}
run(Version::V1);
run(Version::V1Lazy);
async_std::task::block_on(run(Version::V1));
async_std::task::block_on(run(Version::V1Lazy));
}
#[test]
fn no_protocol_found() {
fn run(version: Version) {
let listener = TcpListener::bind(&"127.0.0.1:0".parse().unwrap()).unwrap();
async fn run(version: Version) {
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
let listener_addr = listener.local_addr().unwrap();
let server = listener
.incoming()
.into_future()
.map(|s| s.0.unwrap())
.map_err(|(e, _)| e.into())
.and_then(move |connec| {
let server = async_std::task::spawn(async move {
let connec = listener.accept().await.unwrap().0;
let protos = vec![b"/proto1", b"/proto2"];
listener_select_proto(connec, protos)
})
.and_then(|(proto, io)| io.complete().map(move |_| proto));
let io = match listener_select_proto(connec, protos).await {
Ok((_, io)) => io,
// We don't explicitly check for `Failed` because the client might close the connection when it
// realizes that we have no protocol in common.
Err(_) => return,
};
match io.complete().await {
Err(NegotiationError::Failed) => {},
_ => panic!(),
}
});
let client = TcpStream::connect(&listener_addr)
.from_err()
.and_then(move |connec| {
let client = async_std::task::spawn(async move {
let connec = TcpStream::connect(&listener_addr).await.unwrap();
let protos = vec![b"/proto3", b"/proto4"];
dialer_select_proto(connec, protos, version)
})
.and_then(|(proto, io)| io.complete().map(move |_| proto));
let mut rt = Runtime::new().unwrap();
match rt.block_on(client.join(server)) {
Err(NegotiationError::Failed) => (),
e => panic!("{:?}", e),
let io = match dialer_select_proto(connec, protos.into_iter(), version).await {
Err(NegotiationError::Failed) => return,
Ok((_, io)) => io,
Err(_) => panic!()
};
match io.complete().await {
Err(NegotiationError::Failed) => {},
_ => panic!(),
}
});
server.await;
client.await;
}
run(Version::V1);
run(Version::V1Lazy);
async_std::task::block_on(run(Version::V1));
async_std::task::block_on(run(Version::V1Lazy));
}
#[test]
fn select_proto_parallel() {
fn run(version: Version) {
let listener = TcpListener::bind(&"127.0.0.1:0".parse().unwrap()).unwrap();
async fn run(version: Version) {
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
let listener_addr = listener.local_addr().unwrap();
let server = listener
.incoming()
.into_future()
.map(|s| s.0.unwrap())
.map_err(|(e, _)| e.into())
.and_then(move |connec| {
let server = async_std::task::spawn(async move {
let connec = listener.accept().await.unwrap().0;
let protos = vec![b"/proto1", b"/proto2"];
listener_select_proto(connec, protos)
})
.and_then(|(proto, io)| io.complete().map(move |_| proto));
let (proto, io) = listener_select_proto(connec, protos).await.unwrap();
assert_eq!(proto, b"/proto2");
io.complete().await.unwrap();
});
let client = TcpStream::connect(&listener_addr)
.from_err()
.and_then(move |connec| {
let client = async_std::task::spawn(async move {
let connec = TcpStream::connect(&listener_addr).await.unwrap();
let protos = vec![b"/proto3", b"/proto2"];
dialer_select_proto_parallel(connec, protos.into_iter(), version)
})
.and_then(|(proto, io)| io.complete().map(move |_| proto));
let (proto, io) = dialer_select_proto_parallel(connec, protos.into_iter(), version)
.await.unwrap();
assert_eq!(proto, b"/proto2");
io.complete().await.unwrap();
});
let mut rt = Runtime::new().unwrap();
let (dialer_chosen, listener_chosen) =
rt.block_on(client.join(server)).unwrap();
assert_eq!(dialer_chosen, b"/proto2");
assert_eq!(listener_chosen, b"/proto2");
server.await;
client.await;
}
run(Version::V1);
run(Version::V1Lazy);
async_std::task::block_on(run(Version::V1));
async_std::task::block_on(run(Version::V1Lazy));
}
#[test]
fn select_proto_serial() {
fn run(version: Version) {
let listener = TcpListener::bind(&"127.0.0.1:0".parse().unwrap()).unwrap();
async fn run(version: Version) {
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
let listener_addr = listener.local_addr().unwrap();
let server = listener
.incoming()
.into_future()
.map(|s| s.0.unwrap())
.map_err(|(e, _)| e.into())
.and_then(move |connec| {
let server = async_std::task::spawn(async move {
let connec = listener.accept().await.unwrap().0;
let protos = vec![b"/proto1", b"/proto2"];
listener_select_proto(connec, protos)
})
.and_then(|(proto, io)| io.complete().map(move |_| proto));
let (proto, io) = listener_select_proto(connec, protos).await.unwrap();
assert_eq!(proto, b"/proto2");
io.complete().await.unwrap();
});
let client = TcpStream::connect(&listener_addr)
.from_err()
.and_then(move |connec| {
let client = async_std::task::spawn(async move {
let connec = TcpStream::connect(&listener_addr).await.unwrap();
let protos = vec![b"/proto3", b"/proto2"];
dialer_select_proto_serial(connec, protos.into_iter(), version)
})
.and_then(|(proto, io)| io.complete().map(move |_| proto));
let (proto, io) = dialer_select_proto_serial(connec, protos.into_iter(), version)
.await.unwrap();
assert_eq!(proto, b"/proto2");
io.complete().await.unwrap();
});
let mut rt = Runtime::new().unwrap();
let (dialer_chosen, listener_chosen) =
rt.block_on(client.join(server)).unwrap();
assert_eq!(dialer_chosen, b"/proto2");
assert_eq!(listener_chosen, b"/proto2");
server.await;
client.await;
}
run(Version::V1);
run(Version::V1Lazy);
async_std::task::block_on(run(Version::V1));
async_std::task::block_on(run(Version::V1Lazy));
}