mirror of
https://github.com/fluencelabs/rust-libp2p
synced 2025-06-13 18:11:22 +00:00
Update multistream-select to stable futures (#1484)
* Update multistream-select to stable futures * Fix intradoc links
This commit is contained in:
@ -15,7 +15,7 @@ bs58 = "0.3.0"
|
|||||||
ed25519-dalek = "1.0.0-pre.3"
|
ed25519-dalek = "1.0.0-pre.3"
|
||||||
either = "1.5"
|
either = "1.5"
|
||||||
fnv = "1.0"
|
fnv = "1.0"
|
||||||
futures = { version = "0.3.1", features = ["compat", "io-compat", "executor", "thread-pool"] }
|
futures = { version = "0.3.1", features = ["executor", "thread-pool"] }
|
||||||
futures-timer = "3"
|
futures-timer = "3"
|
||||||
lazy_static = "1.2"
|
lazy_static = "1.2"
|
||||||
libsecp256k1 = { version = "0.3.1", optional = true }
|
libsecp256k1 = { version = "0.3.1", optional = true }
|
||||||
|
@ -41,7 +41,7 @@ mod keys_proto {
|
|||||||
|
|
||||||
/// Multi-address re-export.
|
/// Multi-address re-export.
|
||||||
pub use multiaddr;
|
pub use multiaddr;
|
||||||
pub type Negotiated<T> = futures::compat::Compat01As03<multistream_select::Negotiated<futures::compat::Compat<T>>>;
|
pub type Negotiated<T> = multistream_select::Negotiated<T>;
|
||||||
|
|
||||||
mod peer_id;
|
mod peer_id;
|
||||||
mod translation;
|
mod translation;
|
||||||
|
@ -20,7 +20,7 @@
|
|||||||
|
|
||||||
use crate::{ConnectedPoint, Negotiated};
|
use crate::{ConnectedPoint, Negotiated};
|
||||||
use crate::upgrade::{InboundUpgrade, OutboundUpgrade, UpgradeError, ProtocolName};
|
use crate::upgrade::{InboundUpgrade, OutboundUpgrade, UpgradeError, ProtocolName};
|
||||||
use futures::{future::Either, prelude::*, compat::Compat, compat::Compat01As03, compat::Future01CompatExt};
|
use futures::{future::Either, prelude::*};
|
||||||
use log::debug;
|
use log::debug;
|
||||||
use multistream_select::{self, DialerSelectFuture, ListenerSelectFuture};
|
use multistream_select::{self, DialerSelectFuture, ListenerSelectFuture};
|
||||||
use std::{iter, mem, pin::Pin, task::Context, task::Poll};
|
use std::{iter, mem, pin::Pin, task::Context, task::Poll};
|
||||||
@ -48,7 +48,7 @@ where
|
|||||||
U: InboundUpgrade<Negotiated<C>>,
|
U: InboundUpgrade<Negotiated<C>>,
|
||||||
{
|
{
|
||||||
let iter = up.protocol_info().into_iter().map(NameWrap as fn(_) -> NameWrap<_>);
|
let iter = up.protocol_info().into_iter().map(NameWrap as fn(_) -> NameWrap<_>);
|
||||||
let future = multistream_select::listener_select_proto(Compat::new(conn), iter).compat();
|
let future = multistream_select::listener_select_proto(conn, iter);
|
||||||
InboundUpgradeApply {
|
InboundUpgradeApply {
|
||||||
inner: InboundUpgradeApplyState::Init { future, upgrade: up }
|
inner: InboundUpgradeApplyState::Init { future, upgrade: up }
|
||||||
}
|
}
|
||||||
@ -61,7 +61,7 @@ where
|
|||||||
U: OutboundUpgrade<Negotiated<C>>
|
U: OutboundUpgrade<Negotiated<C>>
|
||||||
{
|
{
|
||||||
let iter = up.protocol_info().into_iter().map(NameWrap as fn(_) -> NameWrap<_>);
|
let iter = up.protocol_info().into_iter().map(NameWrap as fn(_) -> NameWrap<_>);
|
||||||
let future = multistream_select::dialer_select_proto(Compat::new(conn), iter, v).compat();
|
let future = multistream_select::dialer_select_proto(conn, iter, v);
|
||||||
OutboundUpgradeApply {
|
OutboundUpgradeApply {
|
||||||
inner: OutboundUpgradeApplyState::Init { future, upgrade: up }
|
inner: OutboundUpgradeApplyState::Init { future, upgrade: up }
|
||||||
}
|
}
|
||||||
@ -82,7 +82,7 @@ where
|
|||||||
U: InboundUpgrade<Negotiated<C>>,
|
U: InboundUpgrade<Negotiated<C>>,
|
||||||
{
|
{
|
||||||
Init {
|
Init {
|
||||||
future: Compat01As03<ListenerSelectFuture<Compat<C>, NameWrap<U::Info>>>,
|
future: ListenerSelectFuture<C, NameWrap<U::Info>>,
|
||||||
upgrade: U,
|
upgrade: U,
|
||||||
},
|
},
|
||||||
Upgrade {
|
Upgrade {
|
||||||
@ -117,7 +117,7 @@ where
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
self.inner = InboundUpgradeApplyState::Upgrade {
|
self.inner = InboundUpgradeApplyState::Upgrade {
|
||||||
future: Box::pin(upgrade.upgrade_inbound(Compat01As03::new(io), info.0))
|
future: Box::pin(upgrade.upgrade_inbound(io, info.0))
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
InboundUpgradeApplyState::Upgrade { mut future } => {
|
InboundUpgradeApplyState::Upgrade { mut future } => {
|
||||||
@ -158,7 +158,7 @@ where
|
|||||||
U: OutboundUpgrade<Negotiated<C>>
|
U: OutboundUpgrade<Negotiated<C>>
|
||||||
{
|
{
|
||||||
Init {
|
Init {
|
||||||
future: Compat01As03<DialerSelectFuture<Compat<C>, NameWrapIter<<U::InfoIter as IntoIterator>::IntoIter>>>,
|
future: DialerSelectFuture<C, NameWrapIter<<U::InfoIter as IntoIterator>::IntoIter>>,
|
||||||
upgrade: U
|
upgrade: U
|
||||||
},
|
},
|
||||||
Upgrade {
|
Upgrade {
|
||||||
@ -193,7 +193,7 @@ where
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
self.inner = OutboundUpgradeApplyState::Upgrade {
|
self.inner = OutboundUpgradeApplyState::Upgrade {
|
||||||
future: Box::pin(upgrade.upgrade_outbound(Compat01As03::new(connection), info.0))
|
future: Box::pin(upgrade.upgrade_outbound(connection, info.0))
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
OutboundUpgradeApplyState::Upgrade { mut future } => {
|
OutboundUpgradeApplyState::Upgrade { mut future } => {
|
||||||
|
@ -11,14 +11,14 @@ edition = "2018"
|
|||||||
|
|
||||||
[dependencies]
|
[dependencies]
|
||||||
bytes = "0.5"
|
bytes = "0.5"
|
||||||
futures = "0.1"
|
futures = "0.3"
|
||||||
log = "0.4"
|
log = "0.4"
|
||||||
|
pin-project = "0.4.8"
|
||||||
smallvec = "1.0"
|
smallvec = "1.0"
|
||||||
tokio-io = "0.1"
|
unsigned-varint = "0.3.2"
|
||||||
unsigned-varint = "0.3"
|
|
||||||
|
|
||||||
[dev-dependencies]
|
[dev-dependencies]
|
||||||
tokio = "0.1"
|
async-std = "1.2"
|
||||||
tokio-tcp = "0.1"
|
|
||||||
quickcheck = "0.9.0"
|
quickcheck = "0.9.0"
|
||||||
rand = "0.7.2"
|
rand = "0.7.2"
|
||||||
|
rw-stream-sink = "0.2.1"
|
||||||
|
@ -20,12 +20,11 @@
|
|||||||
|
|
||||||
//! Protocol negotiation strategies for the peer acting as the dialer.
|
//! Protocol negotiation strategies for the peer acting as the dialer.
|
||||||
|
|
||||||
use crate::protocol::{Protocol, ProtocolError, MessageIO, Message, Version};
|
|
||||||
use futures::{future::Either, prelude::*};
|
|
||||||
use log::debug;
|
|
||||||
use std::{io, iter, mem, convert::TryFrom};
|
|
||||||
use tokio_io::{AsyncRead, AsyncWrite};
|
|
||||||
use crate::{Negotiated, NegotiationError};
|
use crate::{Negotiated, NegotiationError};
|
||||||
|
use crate::protocol::{Protocol, ProtocolError, MessageIO, Message, Version};
|
||||||
|
|
||||||
|
use futures::{future::Either, prelude::*};
|
||||||
|
use std::{convert::TryFrom as _, io, iter, mem, pin::Pin, task::{Context, Poll}};
|
||||||
|
|
||||||
/// Returns a `Future` that negotiates a protocol on the given I/O stream
|
/// Returns a `Future` that negotiates a protocol on the given I/O stream
|
||||||
/// for a peer acting as the _dialer_ (or _initiator_).
|
/// for a peer acting as the _dialer_ (or _initiator_).
|
||||||
@ -60,9 +59,9 @@ where
|
|||||||
let iter = protocols.into_iter();
|
let iter = protocols.into_iter();
|
||||||
// We choose between the "serial" and "parallel" strategies based on the number of protocols.
|
// We choose between the "serial" and "parallel" strategies based on the number of protocols.
|
||||||
if iter.size_hint().1.map(|n| n <= 3).unwrap_or(false) {
|
if iter.size_hint().1.map(|n| n <= 3).unwrap_or(false) {
|
||||||
Either::A(dialer_select_proto_serial(inner, iter, version))
|
Either::Left(dialer_select_proto_serial(inner, iter, version))
|
||||||
} else {
|
} else {
|
||||||
Either::B(dialer_select_proto_parallel(inner, iter, version))
|
Either::Right(dialer_select_proto_parallel(inner, iter, version))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -129,6 +128,7 @@ where
|
|||||||
|
|
||||||
/// A `Future` returned by [`dialer_select_proto_serial`] which negotiates
|
/// A `Future` returned by [`dialer_select_proto_serial`] which negotiates
|
||||||
/// a protocol iteratively by considering one protocol after the other.
|
/// a protocol iteratively by considering one protocol after the other.
|
||||||
|
#[pin_project::pin_project]
|
||||||
pub struct DialerSelectSeq<R, I>
|
pub struct DialerSelectSeq<R, I>
|
||||||
where
|
where
|
||||||
R: AsyncRead + AsyncWrite,
|
R: AsyncRead + AsyncWrite,
|
||||||
@ -155,83 +155,107 @@ where
|
|||||||
|
|
||||||
impl<R, I> Future for DialerSelectSeq<R, I>
|
impl<R, I> Future for DialerSelectSeq<R, I>
|
||||||
where
|
where
|
||||||
R: AsyncRead + AsyncWrite,
|
// The Unpin bound here is required because we produce a `Negotiated<R>` as the output.
|
||||||
|
// It also makes the implementation considerably easier to write.
|
||||||
|
R: AsyncRead + AsyncWrite + Unpin,
|
||||||
I: Iterator,
|
I: Iterator,
|
||||||
I::Item: AsRef<[u8]>
|
I::Item: AsRef<[u8]>
|
||||||
{
|
{
|
||||||
type Item = (I::Item, Negotiated<R>);
|
type Output = Result<(I::Item, Negotiated<R>), NegotiationError>;
|
||||||
type Error = NegotiationError;
|
|
||||||
|
fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
|
||||||
|
let this = self.project();
|
||||||
|
|
||||||
fn poll(&mut self) -> Poll<Self::Item, Self::Error> {
|
|
||||||
loop {
|
loop {
|
||||||
match mem::replace(&mut self.state, SeqState::Done) {
|
match mem::replace(this.state, SeqState::Done) {
|
||||||
SeqState::SendHeader { mut io } => {
|
SeqState::SendHeader { mut io } => {
|
||||||
if io.start_send(Message::Header(self.version))?.is_not_ready() {
|
match Pin::new(&mut io).poll_ready(cx)? {
|
||||||
self.state = SeqState::SendHeader { io };
|
Poll::Ready(()) => {},
|
||||||
return Ok(Async::NotReady)
|
Poll::Pending => {
|
||||||
|
*this.state = SeqState::SendHeader { io };
|
||||||
|
return Poll::Pending
|
||||||
|
},
|
||||||
}
|
}
|
||||||
let protocol = self.protocols.next().ok_or(NegotiationError::Failed)?;
|
|
||||||
self.state = SeqState::SendProtocol { io, protocol };
|
if let Err(err) = Pin::new(&mut io).start_send(Message::Header(*this.version)) {
|
||||||
|
return Poll::Ready(Err(From::from(err)));
|
||||||
|
}
|
||||||
|
|
||||||
|
let protocol = this.protocols.next().ok_or(NegotiationError::Failed)?;
|
||||||
|
*this.state = SeqState::SendProtocol { io, protocol };
|
||||||
}
|
}
|
||||||
|
|
||||||
SeqState::SendProtocol { mut io, protocol } => {
|
SeqState::SendProtocol { mut io, protocol } => {
|
||||||
let p = Protocol::try_from(protocol.as_ref())?;
|
match Pin::new(&mut io).poll_ready(cx)? {
|
||||||
if io.start_send(Message::Protocol(p.clone()))?.is_not_ready() {
|
Poll::Ready(()) => {},
|
||||||
self.state = SeqState::SendProtocol { io, protocol };
|
Poll::Pending => {
|
||||||
return Ok(Async::NotReady)
|
*this.state = SeqState::SendProtocol { io, protocol };
|
||||||
|
return Poll::Pending
|
||||||
|
},
|
||||||
}
|
}
|
||||||
debug!("Dialer: Proposed protocol: {}", p);
|
|
||||||
if self.protocols.peek().is_some() {
|
let p = Protocol::try_from(protocol.as_ref())?;
|
||||||
self.state = SeqState::FlushProtocol { io, protocol }
|
if let Err(err) = Pin::new(&mut io).start_send(Message::Protocol(p.clone())) {
|
||||||
|
return Poll::Ready(Err(From::from(err)));
|
||||||
|
}
|
||||||
|
log::debug!("Dialer: Proposed protocol: {}", p);
|
||||||
|
|
||||||
|
if this.protocols.peek().is_some() {
|
||||||
|
*this.state = SeqState::FlushProtocol { io, protocol }
|
||||||
} else {
|
} else {
|
||||||
match self.version {
|
match this.version {
|
||||||
Version::V1 => self.state = SeqState::FlushProtocol { io, protocol },
|
Version::V1 => *this.state = SeqState::FlushProtocol { io, protocol },
|
||||||
Version::V1Lazy => {
|
Version::V1Lazy => {
|
||||||
debug!("Dialer: Expecting proposed protocol: {}", p);
|
log::debug!("Dialer: Expecting proposed protocol: {}", p);
|
||||||
let io = Negotiated::expecting(io.into_reader(), p, self.version);
|
let io = Negotiated::expecting(io.into_reader(), p, *this.version);
|
||||||
return Ok(Async::Ready((protocol, io)))
|
return Poll::Ready(Ok((protocol, io)))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
SeqState::FlushProtocol { mut io, protocol } => {
|
SeqState::FlushProtocol { mut io, protocol } => {
|
||||||
if io.poll_complete()?.is_not_ready() {
|
match Pin::new(&mut io).poll_flush(cx)? {
|
||||||
self.state = SeqState::FlushProtocol { io, protocol };
|
Poll::Ready(()) => *this.state = SeqState::AwaitProtocol { io, protocol },
|
||||||
return Ok(Async::NotReady)
|
Poll::Pending => {
|
||||||
|
*this.state = SeqState::FlushProtocol { io, protocol };
|
||||||
|
return Poll::Pending
|
||||||
|
},
|
||||||
}
|
}
|
||||||
self.state = SeqState::AwaitProtocol { io, protocol }
|
|
||||||
}
|
}
|
||||||
|
|
||||||
SeqState::AwaitProtocol { mut io, protocol } => {
|
SeqState::AwaitProtocol { mut io, protocol } => {
|
||||||
let msg = match io.poll()? {
|
let msg = match Pin::new(&mut io).poll_next(cx)? {
|
||||||
Async::NotReady => {
|
Poll::Ready(Some(msg)) => msg,
|
||||||
self.state = SeqState::AwaitProtocol { io, protocol };
|
Poll::Pending => {
|
||||||
return Ok(Async::NotReady)
|
*this.state = SeqState::AwaitProtocol { io, protocol };
|
||||||
|
return Poll::Pending
|
||||||
}
|
}
|
||||||
Async::Ready(None) =>
|
Poll::Ready(None) =>
|
||||||
return Err(NegotiationError::from(
|
return Poll::Ready(Err(NegotiationError::from(
|
||||||
io::Error::from(io::ErrorKind::UnexpectedEof))),
|
io::Error::from(io::ErrorKind::UnexpectedEof)))),
|
||||||
Async::Ready(Some(msg)) => msg,
|
|
||||||
};
|
};
|
||||||
|
|
||||||
match msg {
|
match msg {
|
||||||
Message::Header(v) if v == self.version => {
|
Message::Header(v) if v == *this.version => {
|
||||||
self.state = SeqState::AwaitProtocol { io, protocol };
|
*this.state = SeqState::AwaitProtocol { io, protocol };
|
||||||
}
|
}
|
||||||
Message::Protocol(ref p) if p.as_ref() == protocol.as_ref() => {
|
Message::Protocol(ref p) if p.as_ref() == protocol.as_ref() => {
|
||||||
debug!("Dialer: Received confirmation for protocol: {}", p);
|
log::debug!("Dialer: Received confirmation for protocol: {}", p);
|
||||||
let (io, remaining) = io.into_inner();
|
let (io, remaining) = io.into_inner();
|
||||||
let io = Negotiated::completed(io, remaining);
|
let io = Negotiated::completed(io, remaining);
|
||||||
return Ok(Async::Ready((protocol, io)))
|
return Poll::Ready(Ok((protocol, io)));
|
||||||
}
|
}
|
||||||
Message::NotAvailable => {
|
Message::NotAvailable => {
|
||||||
debug!("Dialer: Received rejection of protocol: {}",
|
log::debug!("Dialer: Received rejection of protocol: {}",
|
||||||
String::from_utf8_lossy(protocol.as_ref()));
|
String::from_utf8_lossy(protocol.as_ref()));
|
||||||
let protocol = self.protocols.next()
|
let protocol = this.protocols.next().ok_or(NegotiationError::Failed)?;
|
||||||
.ok_or(NegotiationError::Failed)?;
|
*this.state = SeqState::SendProtocol { io, protocol }
|
||||||
self.state = SeqState::SendProtocol { io, protocol }
|
|
||||||
}
|
}
|
||||||
_ => return Err(ProtocolError::InvalidMessage.into())
|
_ => return Poll::Ready(Err(ProtocolError::InvalidMessage.into())),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
SeqState::Done => panic!("SeqState::poll called after completion")
|
SeqState::Done => panic!("SeqState::poll called after completion")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -241,6 +265,7 @@ where
|
|||||||
/// A `Future` returned by [`dialer_select_proto_parallel`] which negotiates
|
/// A `Future` returned by [`dialer_select_proto_parallel`] which negotiates
|
||||||
/// a protocol selectively by considering all supported protocols of the remote
|
/// a protocol selectively by considering all supported protocols of the remote
|
||||||
/// "in parallel".
|
/// "in parallel".
|
||||||
|
#[pin_project::pin_project]
|
||||||
pub struct DialerSelectPar<R, I>
|
pub struct DialerSelectPar<R, I>
|
||||||
where
|
where
|
||||||
R: AsyncRead + AsyncWrite,
|
R: AsyncRead + AsyncWrite,
|
||||||
@ -267,76 +292,110 @@ where
|
|||||||
|
|
||||||
impl<R, I> Future for DialerSelectPar<R, I>
|
impl<R, I> Future for DialerSelectPar<R, I>
|
||||||
where
|
where
|
||||||
R: AsyncRead + AsyncWrite,
|
// The Unpin bound here is required because we produce a `Negotiated<R>` as the output.
|
||||||
|
// It also makes the implementation considerably easier to write.
|
||||||
|
R: AsyncRead + AsyncWrite + Unpin,
|
||||||
I: Iterator,
|
I: Iterator,
|
||||||
I::Item: AsRef<[u8]>
|
I::Item: AsRef<[u8]>
|
||||||
{
|
{
|
||||||
type Item = (I::Item, Negotiated<R>);
|
type Output = Result<(I::Item, Negotiated<R>), NegotiationError>;
|
||||||
type Error = NegotiationError;
|
|
||||||
|
fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
|
||||||
|
let this = self.project();
|
||||||
|
|
||||||
fn poll(&mut self) -> Poll<Self::Item, Self::Error> {
|
|
||||||
loop {
|
loop {
|
||||||
match mem::replace(&mut self.state, ParState::Done) {
|
match mem::replace(this.state, ParState::Done) {
|
||||||
ParState::SendHeader { mut io } => {
|
ParState::SendHeader { mut io } => {
|
||||||
if io.start_send(Message::Header(self.version))?.is_not_ready() {
|
match Pin::new(&mut io).poll_ready(cx)? {
|
||||||
self.state = ParState::SendHeader { io };
|
Poll::Ready(()) => {},
|
||||||
return Ok(Async::NotReady)
|
Poll::Pending => {
|
||||||
|
*this.state = ParState::SendHeader { io };
|
||||||
|
return Poll::Pending
|
||||||
|
},
|
||||||
}
|
}
|
||||||
self.state = ParState::SendProtocolsRequest { io };
|
|
||||||
|
if let Err(err) = Pin::new(&mut io).start_send(Message::Header(*this.version)) {
|
||||||
|
return Poll::Ready(Err(From::from(err)));
|
||||||
|
}
|
||||||
|
|
||||||
|
*this.state = ParState::SendProtocolsRequest { io };
|
||||||
}
|
}
|
||||||
|
|
||||||
ParState::SendProtocolsRequest { mut io } => {
|
ParState::SendProtocolsRequest { mut io } => {
|
||||||
if io.start_send(Message::ListProtocols)?.is_not_ready() {
|
match Pin::new(&mut io).poll_ready(cx)? {
|
||||||
self.state = ParState::SendProtocolsRequest { io };
|
Poll::Ready(()) => {},
|
||||||
return Ok(Async::NotReady)
|
Poll::Pending => {
|
||||||
|
*this.state = ParState::SendProtocolsRequest { io };
|
||||||
|
return Poll::Pending
|
||||||
|
},
|
||||||
}
|
}
|
||||||
debug!("Dialer: Requested supported protocols.");
|
|
||||||
self.state = ParState::Flush { io }
|
if let Err(err) = Pin::new(&mut io).start_send(Message::ListProtocols) {
|
||||||
|
return Poll::Ready(Err(From::from(err)));
|
||||||
|
}
|
||||||
|
|
||||||
|
log::debug!("Dialer: Requested supported protocols.");
|
||||||
|
*this.state = ParState::Flush { io }
|
||||||
}
|
}
|
||||||
|
|
||||||
ParState::Flush { mut io } => {
|
ParState::Flush { mut io } => {
|
||||||
if io.poll_complete()?.is_not_ready() {
|
match Pin::new(&mut io).poll_flush(cx)? {
|
||||||
self.state = ParState::Flush { io };
|
Poll::Ready(()) => *this.state = ParState::RecvProtocols { io },
|
||||||
return Ok(Async::NotReady)
|
Poll::Pending => {
|
||||||
|
*this.state = ParState::Flush { io };
|
||||||
|
return Poll::Pending
|
||||||
|
},
|
||||||
}
|
}
|
||||||
self.state = ParState::RecvProtocols { io }
|
|
||||||
}
|
}
|
||||||
|
|
||||||
ParState::RecvProtocols { mut io } => {
|
ParState::RecvProtocols { mut io } => {
|
||||||
let msg = match io.poll()? {
|
let msg = match Pin::new(&mut io).poll_next(cx)? {
|
||||||
Async::NotReady => {
|
Poll::Ready(Some(msg)) => msg,
|
||||||
self.state = ParState::RecvProtocols { io };
|
Poll::Pending => {
|
||||||
return Ok(Async::NotReady)
|
*this.state = ParState::RecvProtocols { io };
|
||||||
|
return Poll::Pending
|
||||||
}
|
}
|
||||||
Async::Ready(None) =>
|
Poll::Ready(None) =>
|
||||||
return Err(NegotiationError::from(
|
return Poll::Ready(Err(NegotiationError::from(
|
||||||
io::Error::from(io::ErrorKind::UnexpectedEof))),
|
io::Error::from(io::ErrorKind::UnexpectedEof)))),
|
||||||
Async::Ready(Some(msg)) => msg,
|
|
||||||
};
|
};
|
||||||
|
|
||||||
match &msg {
|
match &msg {
|
||||||
Message::Header(v) if v == &self.version => {
|
Message::Header(v) if v == this.version => {
|
||||||
self.state = ParState::RecvProtocols { io }
|
*this.state = ParState::RecvProtocols { io }
|
||||||
}
|
}
|
||||||
Message::Protocols(supported) => {
|
Message::Protocols(supported) => {
|
||||||
let protocol = self.protocols.by_ref()
|
let protocol = this.protocols.by_ref()
|
||||||
.find(|p| supported.iter().any(|s|
|
.find(|p| supported.iter().any(|s|
|
||||||
s.as_ref() == p.as_ref()))
|
s.as_ref() == p.as_ref()))
|
||||||
.ok_or(NegotiationError::Failed)?;
|
.ok_or(NegotiationError::Failed)?;
|
||||||
debug!("Dialer: Found supported protocol: {}",
|
log::debug!("Dialer: Found supported protocol: {}",
|
||||||
String::from_utf8_lossy(protocol.as_ref()));
|
String::from_utf8_lossy(protocol.as_ref()));
|
||||||
self.state = ParState::SendProtocol { io, protocol };
|
*this.state = ParState::SendProtocol { io, protocol };
|
||||||
}
|
}
|
||||||
_ => return Err(ProtocolError::InvalidMessage.into())
|
_ => return Poll::Ready(Err(ProtocolError::InvalidMessage.into())),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
ParState::SendProtocol { mut io, protocol } => {
|
ParState::SendProtocol { mut io, protocol } => {
|
||||||
let p = Protocol::try_from(protocol.as_ref())?;
|
match Pin::new(&mut io).poll_ready(cx)? {
|
||||||
if io.start_send(Message::Protocol(p.clone()))?.is_not_ready() {
|
Poll::Ready(()) => {},
|
||||||
self.state = ParState::SendProtocol { io, protocol };
|
Poll::Pending => {
|
||||||
return Ok(Async::NotReady)
|
*this.state = ParState::SendProtocol { io, protocol };
|
||||||
|
return Poll::Pending
|
||||||
|
},
|
||||||
}
|
}
|
||||||
debug!("Dialer: Expecting proposed protocol: {}", p);
|
|
||||||
let io = Negotiated::expecting(io.into_reader(), p, self.version);
|
let p = Protocol::try_from(protocol.as_ref())?;
|
||||||
return Ok(Async::Ready((protocol, io)))
|
if let Err(err) = Pin::new(&mut io).start_send(Message::Protocol(p.clone())) {
|
||||||
|
return Poll::Ready(Err(From::from(err)));
|
||||||
|
}
|
||||||
|
log::debug!("Dialer: Expecting proposed protocol: {}", p);
|
||||||
|
|
||||||
|
let io = Negotiated::expecting(io.into_reader(), p, *this.version);
|
||||||
|
return Poll::Ready(Ok((protocol, io)))
|
||||||
}
|
}
|
||||||
|
|
||||||
ParState::Done => panic!("ParState::poll called after completion")
|
ParState::Done => panic!("ParState::poll called after completion")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -18,11 +18,9 @@
|
|||||||
// FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
|
// FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
|
||||||
// DEALINGS IN THE SOFTWARE.
|
// DEALINGS IN THE SOFTWARE.
|
||||||
|
|
||||||
use bytes::{Bytes, BytesMut, Buf, BufMut};
|
use bytes::{Bytes, BytesMut, Buf as _, BufMut as _};
|
||||||
use futures::{try_ready, Async, Poll, Sink, StartSend, Stream, AsyncSink};
|
use futures::{prelude::*, io::IoSlice};
|
||||||
use std::{io, u16};
|
use std::{convert::TryFrom as _, io, pin::Pin, task::{Poll, Context}, u16};
|
||||||
use tokio_io::{AsyncRead, AsyncWrite};
|
|
||||||
use unsigned_varint as uvi;
|
|
||||||
|
|
||||||
const MAX_LEN_BYTES: u16 = 2;
|
const MAX_LEN_BYTES: u16 = 2;
|
||||||
const MAX_FRAME_SIZE: u16 = (1 << (MAX_LEN_BYTES * 8 - MAX_LEN_BYTES)) - 1;
|
const MAX_FRAME_SIZE: u16 = (1 << (MAX_LEN_BYTES * 8 - MAX_LEN_BYTES)) - 1;
|
||||||
@ -34,9 +32,11 @@ const DEFAULT_BUFFER_SIZE: usize = 64;
|
|||||||
/// We purposely only support a frame sizes up to 16KiB (2 bytes unsigned varint
|
/// We purposely only support a frame sizes up to 16KiB (2 bytes unsigned varint
|
||||||
/// frame length). Frames mostly consist in a short protocol name, which is highly
|
/// frame length). Frames mostly consist in a short protocol name, which is highly
|
||||||
/// unlikely to be more than 16KiB long.
|
/// unlikely to be more than 16KiB long.
|
||||||
|
#[pin_project::pin_project]
|
||||||
#[derive(Debug)]
|
#[derive(Debug)]
|
||||||
pub struct LengthDelimited<R> {
|
pub struct LengthDelimited<R> {
|
||||||
/// The inner I/O resource.
|
/// The inner I/O resource.
|
||||||
|
#[pin]
|
||||||
inner: R,
|
inner: R,
|
||||||
/// Read buffer for a single incoming unsigned-varint length-delimited frame.
|
/// Read buffer for a single incoming unsigned-varint length-delimited frame.
|
||||||
read_buffer: BytesMut,
|
read_buffer: BytesMut,
|
||||||
@ -76,20 +76,7 @@ impl<R> LengthDelimited<R> {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Returns a reference to the underlying I/O stream.
|
/// Drops the [`LengthDelimited`] resource, yielding the underlying I/O stream
|
||||||
pub fn inner_ref(&self) -> &R {
|
|
||||||
&self.inner
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Returns a mutable reference to the underlying I/O stream.
|
|
||||||
///
|
|
||||||
/// > **Note**: Care should be taken to not tamper with the underlying stream of data
|
|
||||||
/// > coming in, as it may corrupt the stream of frames.
|
|
||||||
pub fn inner_mut(&mut self) -> &mut R {
|
|
||||||
&mut self.inner
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Drops the `LengthDelimited` resource, yielding the underlying I/O stream
|
|
||||||
/// together with the remaining write buffer containing the uvi-framed data
|
/// together with the remaining write buffer containing the uvi-framed data
|
||||||
/// that has not yet been written to the underlying I/O stream.
|
/// that has not yet been written to the underlying I/O stream.
|
||||||
///
|
///
|
||||||
@ -107,7 +94,7 @@ impl<R> LengthDelimited<R> {
|
|||||||
(self.inner, self.write_buffer)
|
(self.inner, self.write_buffer)
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Converts the `LengthDelimited` into a `LengthDelimitedReader`, dropping the
|
/// Converts the [`LengthDelimited`] into a [`LengthDelimitedReader`], dropping the
|
||||||
/// uvi-framed `Sink` in favour of direct `AsyncWrite` access to the underlying
|
/// uvi-framed `Sink` in favour of direct `AsyncWrite` access to the underlying
|
||||||
/// I/O stream.
|
/// I/O stream.
|
||||||
///
|
///
|
||||||
@ -121,25 +108,29 @@ impl<R> LengthDelimited<R> {
|
|||||||
/// Writes all buffered frame data to the underlying I/O stream,
|
/// Writes all buffered frame data to the underlying I/O stream,
|
||||||
/// _without flushing it_.
|
/// _without flushing it_.
|
||||||
///
|
///
|
||||||
/// After this method returns `Async::Ready`, the write buffer of frames
|
/// After this method returns `Poll::Ready`, the write buffer of frames
|
||||||
/// submitted to the `Sink` is guaranteed to be empty.
|
/// submitted to the `Sink` is guaranteed to be empty.
|
||||||
pub fn poll_write_buffer(&mut self) -> Poll<(), io::Error>
|
pub fn poll_write_buffer(self: Pin<&mut Self>, cx: &mut Context)
|
||||||
|
-> Poll<Result<(), io::Error>>
|
||||||
where
|
where
|
||||||
R: AsyncWrite
|
R: AsyncWrite
|
||||||
{
|
{
|
||||||
while !self.write_buffer.is_empty() {
|
let mut this = self.project();
|
||||||
let n = try_ready!(self.inner.poll_write(&self.write_buffer));
|
|
||||||
|
|
||||||
if n == 0 {
|
while !this.write_buffer.is_empty() {
|
||||||
return Err(io::Error::new(
|
match this.inner.as_mut().poll_write(cx, &this.write_buffer) {
|
||||||
io::ErrorKind::WriteZero,
|
Poll::Pending => return Poll::Pending,
|
||||||
"Failed to write buffered frame."))
|
Poll::Ready(Ok(0)) => {
|
||||||
|
return Poll::Ready(Err(io::Error::new(
|
||||||
|
io::ErrorKind::WriteZero,
|
||||||
|
"Failed to write buffered frame.")))
|
||||||
|
}
|
||||||
|
Poll::Ready(Ok(n)) => this.write_buffer.advance(n),
|
||||||
|
Poll::Ready(Err(err)) => return Poll::Ready(Err(err)),
|
||||||
}
|
}
|
||||||
|
|
||||||
self.write_buffer.advance(n);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
Ok(Async::Ready(()))
|
Poll::Ready(Ok(()))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -147,72 +138,67 @@ impl<R> Stream for LengthDelimited<R>
|
|||||||
where
|
where
|
||||||
R: AsyncRead
|
R: AsyncRead
|
||||||
{
|
{
|
||||||
type Item = Bytes;
|
type Item = Result<Bytes, io::Error>;
|
||||||
type Error = io::Error;
|
|
||||||
|
fn poll_next(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Option<Self::Item>> {
|
||||||
|
let mut this = self.project();
|
||||||
|
|
||||||
fn poll(&mut self) -> Poll<Option<Self::Item>, Self::Error> {
|
|
||||||
loop {
|
loop {
|
||||||
match &mut self.read_state {
|
match this.read_state {
|
||||||
ReadState::ReadLength { buf, pos } => {
|
ReadState::ReadLength { buf, pos } => {
|
||||||
match self.inner.read(&mut buf[*pos .. *pos + 1]) {
|
match this.inner.as_mut().poll_read(cx, &mut buf[*pos .. *pos + 1]) {
|
||||||
Ok(0) => {
|
Poll::Ready(Ok(0)) => {
|
||||||
if *pos == 0 {
|
if *pos == 0 {
|
||||||
return Ok(Async::Ready(None));
|
return Poll::Ready(None);
|
||||||
} else {
|
} else {
|
||||||
return Err(io::ErrorKind::UnexpectedEof.into());
|
return Poll::Ready(Some(Err(io::ErrorKind::UnexpectedEof.into())));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
Ok(n) => {
|
Poll::Ready(Ok(n)) => {
|
||||||
debug_assert_eq!(n, 1);
|
debug_assert_eq!(n, 1);
|
||||||
*pos += n;
|
*pos += n;
|
||||||
}
|
}
|
||||||
Err(ref err) if err.kind() == io::ErrorKind::WouldBlock => {
|
Poll::Ready(Err(err)) => return Poll::Ready(Some(Err(err))),
|
||||||
return Ok(Async::NotReady);
|
Poll::Pending => return Poll::Pending,
|
||||||
}
|
|
||||||
Err(err) => {
|
|
||||||
return Err(err);
|
|
||||||
}
|
|
||||||
};
|
};
|
||||||
|
|
||||||
if (buf[*pos - 1] & 0x80) == 0 {
|
if (buf[*pos - 1] & 0x80) == 0 {
|
||||||
// MSB is not set, indicating the end of the length prefix.
|
// MSB is not set, indicating the end of the length prefix.
|
||||||
let (len, _) = uvi::decode::u16(buf).map_err(|e| {
|
let (len, _) = unsigned_varint::decode::u16(buf)
|
||||||
log::debug!("invalid length prefix: {}", e);
|
.map_err(|e| {
|
||||||
io::Error::new(io::ErrorKind::InvalidData, "invalid length prefix")
|
log::debug!("invalid length prefix: {}", e);
|
||||||
})?;
|
io::Error::new(io::ErrorKind::InvalidData, "invalid length prefix")
|
||||||
|
})?;
|
||||||
|
|
||||||
if len >= 1 {
|
if len >= 1 {
|
||||||
self.read_state = ReadState::ReadData { len, pos: 0 };
|
*this.read_state = ReadState::ReadData { len, pos: 0 };
|
||||||
self.read_buffer.resize(len as usize, 0);
|
this.read_buffer.resize(len as usize, 0);
|
||||||
} else {
|
} else {
|
||||||
debug_assert_eq!(len, 0);
|
debug_assert_eq!(len, 0);
|
||||||
self.read_state = ReadState::default();
|
*this.read_state = ReadState::default();
|
||||||
return Ok(Async::Ready(Some(Bytes::new())));
|
return Poll::Ready(Some(Ok(Bytes::new())));
|
||||||
}
|
}
|
||||||
} else if *pos == MAX_LEN_BYTES as usize {
|
} else if *pos == MAX_LEN_BYTES as usize {
|
||||||
// MSB signals more length bytes but we have already read the maximum.
|
// MSB signals more length bytes but we have already read the maximum.
|
||||||
// See the module documentation about the max frame len.
|
// See the module documentation about the max frame len.
|
||||||
return Err(io::Error::new(
|
return Poll::Ready(Some(Err(io::Error::new(
|
||||||
io::ErrorKind::InvalidData,
|
io::ErrorKind::InvalidData,
|
||||||
"Maximum frame length exceeded"));
|
"Maximum frame length exceeded"))));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
ReadState::ReadData { len, pos } => {
|
ReadState::ReadData { len, pos } => {
|
||||||
match self.inner.read(&mut self.read_buffer[*pos..]) {
|
match this.inner.as_mut().poll_read(cx, &mut this.read_buffer[*pos..]) {
|
||||||
Ok(0) => return Err(io::ErrorKind::UnexpectedEof.into()),
|
Poll::Ready(Ok(0)) => return Poll::Ready(Some(Err(io::ErrorKind::UnexpectedEof.into()))),
|
||||||
Ok(n) => *pos += n,
|
Poll::Ready(Ok(n)) => *pos += n,
|
||||||
Err(err) =>
|
Poll::Pending => return Poll::Pending,
|
||||||
if err.kind() == io::ErrorKind::WouldBlock {
|
Poll::Ready(Err(err)) => return Poll::Ready(Some(Err(err))),
|
||||||
return Ok(Async::NotReady)
|
|
||||||
} else {
|
|
||||||
return Err(err)
|
|
||||||
}
|
|
||||||
};
|
};
|
||||||
|
|
||||||
if *pos == *len as usize {
|
if *pos == *len as usize {
|
||||||
// Finished reading the frame.
|
// Finished reading the frame.
|
||||||
let frame = self.read_buffer.split_off(0).freeze();
|
let frame = this.read_buffer.split_off(0).freeze();
|
||||||
self.read_state = ReadState::default();
|
*this.read_state = ReadState::default();
|
||||||
return Ok(Async::Ready(Some(frame)));
|
return Poll::Ready(Some(Ok(frame)));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -220,58 +206,87 @@ where
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<R> Sink for LengthDelimited<R>
|
impl<R> Sink<Bytes> for LengthDelimited<R>
|
||||||
where
|
where
|
||||||
R: AsyncWrite,
|
R: AsyncWrite,
|
||||||
{
|
{
|
||||||
type SinkItem = Bytes;
|
type Error = io::Error;
|
||||||
type SinkError = io::Error;
|
|
||||||
|
|
||||||
fn start_send(&mut self, msg: Self::SinkItem) -> StartSend<Self::SinkItem, Self::SinkError> {
|
fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
|
||||||
// Use the maximum frame length also as a (soft) upper limit
|
// Use the maximum frame length also as a (soft) upper limit
|
||||||
// for the entire write buffer. The actual (hard) limit is thus
|
// for the entire write buffer. The actual (hard) limit is thus
|
||||||
// implied to be roughly 2 * MAX_FRAME_SIZE.
|
// implied to be roughly 2 * MAX_FRAME_SIZE.
|
||||||
if self.write_buffer.len() >= MAX_FRAME_SIZE as usize {
|
if self.as_mut().project().write_buffer.len() >= MAX_FRAME_SIZE as usize {
|
||||||
self.poll_complete()?;
|
match self.as_mut().poll_write_buffer(cx) {
|
||||||
if self.write_buffer.len() >= MAX_FRAME_SIZE as usize {
|
Poll::Ready(Ok(())) => {},
|
||||||
return Ok(AsyncSink::NotReady(msg))
|
Poll::Ready(Err(err)) => return Poll::Ready(Err(err)),
|
||||||
|
Poll::Pending => return Poll::Pending,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
debug_assert!(self.as_mut().project().write_buffer.is_empty());
|
||||||
}
|
}
|
||||||
|
|
||||||
let len = msg.len() as u16;
|
Poll::Ready(Ok(()))
|
||||||
if len > MAX_FRAME_SIZE {
|
|
||||||
return Err(io::Error::new(
|
|
||||||
io::ErrorKind::InvalidData,
|
|
||||||
"Maximum frame size exceeded."))
|
|
||||||
}
|
|
||||||
|
|
||||||
let mut uvi_buf = uvi::encode::u16_buffer();
|
|
||||||
let uvi_len = uvi::encode::u16(len, &mut uvi_buf);
|
|
||||||
self.write_buffer.reserve(len as usize + uvi_len.len());
|
|
||||||
self.write_buffer.put(uvi_len);
|
|
||||||
self.write_buffer.put(msg);
|
|
||||||
|
|
||||||
Ok(AsyncSink::Ready)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
fn poll_complete(&mut self) -> Poll<(), Self::SinkError> {
|
fn start_send(self: Pin<&mut Self>, item: Bytes) -> Result<(), Self::Error> {
|
||||||
|
let this = self.project();
|
||||||
|
|
||||||
|
let len = match u16::try_from(item.len()) {
|
||||||
|
Ok(len) if len <= MAX_FRAME_SIZE => len,
|
||||||
|
_ => {
|
||||||
|
return Err(io::Error::new(
|
||||||
|
io::ErrorKind::InvalidData,
|
||||||
|
"Maximum frame size exceeded."))
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
let mut uvi_buf = unsigned_varint::encode::u16_buffer();
|
||||||
|
let uvi_len = unsigned_varint::encode::u16(len, &mut uvi_buf);
|
||||||
|
this.write_buffer.reserve(len as usize + uvi_len.len());
|
||||||
|
this.write_buffer.put(uvi_len);
|
||||||
|
this.write_buffer.put(item);
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
|
||||||
// Write all buffered frame data to the underlying I/O stream.
|
// Write all buffered frame data to the underlying I/O stream.
|
||||||
try_ready!(self.poll_write_buffer());
|
match LengthDelimited::poll_write_buffer(self.as_mut(), cx) {
|
||||||
|
Poll::Ready(Ok(())) => {},
|
||||||
|
Poll::Ready(Err(err)) => return Poll::Ready(Err(err)),
|
||||||
|
Poll::Pending => return Poll::Pending,
|
||||||
|
}
|
||||||
|
|
||||||
|
let this = self.project();
|
||||||
|
debug_assert!(this.write_buffer.is_empty());
|
||||||
|
|
||||||
// Flush the underlying I/O stream.
|
// Flush the underlying I/O stream.
|
||||||
try_ready!(self.inner.poll_flush());
|
this.inner.poll_flush(cx)
|
||||||
return Ok(Async::Ready(()));
|
|
||||||
}
|
}
|
||||||
|
|
||||||
fn close(&mut self) -> Poll<(), Self::SinkError> {
|
fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
|
||||||
try_ready!(self.poll_complete());
|
// Write all buffered frame data to the underlying I/O stream.
|
||||||
Ok(self.inner.shutdown()?)
|
match LengthDelimited::poll_write_buffer(self.as_mut(), cx) {
|
||||||
|
Poll::Ready(Ok(())) => {},
|
||||||
|
Poll::Ready(Err(err)) => return Poll::Ready(Err(err)),
|
||||||
|
Poll::Pending => return Poll::Pending,
|
||||||
|
}
|
||||||
|
|
||||||
|
let this = self.project();
|
||||||
|
debug_assert!(this.write_buffer.is_empty());
|
||||||
|
|
||||||
|
// Close the underlying I/O stream.
|
||||||
|
this.inner.poll_close(cx)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// A `LengthDelimitedReader` implements a `Stream` of uvi-length-delimited
|
/// A `LengthDelimitedReader` implements a `Stream` of uvi-length-delimited
|
||||||
/// frames on an underlying I/O resource combined with direct `AsyncWrite` access.
|
/// frames on an underlying I/O resource combined with direct `AsyncWrite` access.
|
||||||
|
#[pin_project::pin_project]
|
||||||
#[derive(Debug)]
|
#[derive(Debug)]
|
||||||
pub struct LengthDelimitedReader<R> {
|
pub struct LengthDelimitedReader<R> {
|
||||||
|
#[pin]
|
||||||
inner: LengthDelimited<R>
|
inner: LengthDelimited<R>
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -284,58 +299,23 @@ impl<R> LengthDelimitedReader<R> {
|
|||||||
/// # Panic
|
/// # Panic
|
||||||
///
|
///
|
||||||
/// Will panic if called while there is data in the read or write buffer.
|
/// Will panic if called while there is data in the read or write buffer.
|
||||||
/// The read buffer is guaranteed to be empty whenever [`Stream::poll`] yields
|
/// The read buffer is guaranteed to be empty whenever [`Stream::poll_next`]
|
||||||
/// a new `Message`. The write buffer is guaranteed to be empty whenever
|
/// yield a new `Message`. The write buffer is guaranteed to be empty whenever
|
||||||
/// [`LengthDelimited::poll_write_buffer`] yields [`Async::Ready`] or after
|
/// [`LengthDelimited::poll_write_buffer`] yields [`Poll::Ready`] or after
|
||||||
/// the [`Sink`] has been completely flushed via [`Sink::poll_complete`].
|
/// the [`Sink`] has been completely flushed via [`Sink::poll_flush`].
|
||||||
pub fn into_inner(self) -> (R, BytesMut) {
|
pub fn into_inner(self) -> (R, BytesMut) {
|
||||||
self.inner.into_inner()
|
self.inner.into_inner()
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Returns a reference to the underlying I/O stream.
|
|
||||||
pub fn inner_ref(&self) -> &R {
|
|
||||||
self.inner.inner_ref()
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Returns a mutable reference to the underlying I/O stream.
|
|
||||||
///
|
|
||||||
/// > **Note**: Care should be taken to not tamper with the underlying stream of data
|
|
||||||
/// > coming in, as it may corrupt the stream of frames.
|
|
||||||
pub fn inner_mut(&mut self) -> &mut R {
|
|
||||||
self.inner.inner_mut()
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<R> Stream for LengthDelimitedReader<R>
|
impl<R> Stream for LengthDelimitedReader<R>
|
||||||
where
|
where
|
||||||
R: AsyncRead
|
R: AsyncRead
|
||||||
{
|
{
|
||||||
type Item = Bytes;
|
type Item = Result<Bytes, io::Error>;
|
||||||
type Error = io::Error;
|
|
||||||
|
|
||||||
fn poll(&mut self) -> Poll<Option<Self::Item>, Self::Error> {
|
fn poll_next(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Option<Self::Item>> {
|
||||||
self.inner.poll()
|
self.project().inner.poll_next(cx)
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl<R> io::Write for LengthDelimitedReader<R>
|
|
||||||
where
|
|
||||||
R: AsyncWrite
|
|
||||||
{
|
|
||||||
fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
|
|
||||||
while !self.inner.write_buffer.is_empty() {
|
|
||||||
if self.inner.poll_write_buffer()?.is_not_ready() {
|
|
||||||
return Err(io::ErrorKind::WouldBlock.into())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
self.inner_mut().write(buf)
|
|
||||||
}
|
|
||||||
|
|
||||||
fn flush(&mut self) -> io::Result<()> {
|
|
||||||
match self.inner.poll_complete()? {
|
|
||||||
Async::Ready(()) => Ok(()),
|
|
||||||
Async::NotReady => Err(io::ErrorKind::WouldBlock.into())
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -343,23 +323,62 @@ impl<R> AsyncWrite for LengthDelimitedReader<R>
|
|||||||
where
|
where
|
||||||
R: AsyncWrite
|
R: AsyncWrite
|
||||||
{
|
{
|
||||||
fn shutdown(&mut self) -> Poll<(), io::Error> {
|
fn poll_write(self: Pin<&mut Self>, cx: &mut Context, buf: &[u8])
|
||||||
try_ready!(self.inner.poll_complete());
|
-> Poll<Result<usize, io::Error>>
|
||||||
self.inner_mut().shutdown()
|
{
|
||||||
|
// `this` here designates the `LengthDelimited`.
|
||||||
|
let mut this = self.project().inner;
|
||||||
|
|
||||||
|
// We need to flush any data previously written with the `LengthDelimited`.
|
||||||
|
match LengthDelimited::poll_write_buffer(this.as_mut(), cx) {
|
||||||
|
Poll::Ready(Ok(())) => {},
|
||||||
|
Poll::Ready(Err(err)) => return Poll::Ready(Err(err)),
|
||||||
|
Poll::Pending => return Poll::Pending,
|
||||||
|
}
|
||||||
|
debug_assert!(this.write_buffer.is_empty());
|
||||||
|
|
||||||
|
this.project().inner.poll_write(cx, buf)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), io::Error>> {
|
||||||
|
self.project().inner.poll_flush(cx)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn poll_close(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), io::Error>> {
|
||||||
|
self.project().inner.poll_close(cx)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn poll_write_vectored(self: Pin<&mut Self>, cx: &mut Context, bufs: &[IoSlice])
|
||||||
|
-> Poll<Result<usize, io::Error>>
|
||||||
|
{
|
||||||
|
// `this` here designates the `LengthDelimited`.
|
||||||
|
let mut this = self.project().inner;
|
||||||
|
|
||||||
|
// We need to flush any data previously written with the `LengthDelimited`.
|
||||||
|
match LengthDelimited::poll_write_buffer(this.as_mut(), cx) {
|
||||||
|
Poll::Ready(Ok(())) => {},
|
||||||
|
Poll::Ready(Err(err)) => return Poll::Ready(Err(err)),
|
||||||
|
Poll::Pending => return Poll::Pending,
|
||||||
|
}
|
||||||
|
debug_assert!(this.write_buffer.is_empty());
|
||||||
|
|
||||||
|
this.project().inner.poll_write_vectored(cx, bufs)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod tests {
|
mod tests {
|
||||||
use futures::{Future, Stream};
|
|
||||||
use crate::length_delimited::LengthDelimited;
|
use crate::length_delimited::LengthDelimited;
|
||||||
use std::io::{Cursor, ErrorKind};
|
use async_std::net::{TcpListener, TcpStream};
|
||||||
|
use futures::{prelude::*, io::Cursor};
|
||||||
|
use quickcheck::*;
|
||||||
|
use std::io::ErrorKind;
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn basic_read() {
|
fn basic_read() {
|
||||||
let data = vec![6, 9, 8, 7, 6, 5, 4];
|
let data = vec![6, 9, 8, 7, 6, 5, 4];
|
||||||
let framed = LengthDelimited::new(Cursor::new(data));
|
let framed = LengthDelimited::new(Cursor::new(data));
|
||||||
let recved = framed.collect().wait().unwrap();
|
let recved = futures::executor::block_on(framed.try_collect::<Vec<_>>()).unwrap();
|
||||||
assert_eq!(recved, vec![vec![9, 8, 7, 6, 5, 4]]);
|
assert_eq!(recved, vec![vec![9, 8, 7, 6, 5, 4]]);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -367,7 +386,7 @@ mod tests {
|
|||||||
fn basic_read_two() {
|
fn basic_read_two() {
|
||||||
let data = vec![6, 9, 8, 7, 6, 5, 4, 3, 9, 8, 7];
|
let data = vec![6, 9, 8, 7, 6, 5, 4, 3, 9, 8, 7];
|
||||||
let framed = LengthDelimited::new(Cursor::new(data));
|
let framed = LengthDelimited::new(Cursor::new(data));
|
||||||
let recved = framed.collect().wait().unwrap();
|
let recved = futures::executor::block_on(framed.try_collect::<Vec<_>>()).unwrap();
|
||||||
assert_eq!(recved, vec![vec![9, 8, 7, 6, 5, 4], vec![9, 8, 7]]);
|
assert_eq!(recved, vec![vec![9, 8, 7, 6, 5, 4], vec![9, 8, 7]]);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -378,13 +397,10 @@ mod tests {
|
|||||||
let frame = (0..len).map(|n| (n & 0xff) as u8).collect::<Vec<_>>();
|
let frame = (0..len).map(|n| (n & 0xff) as u8).collect::<Vec<_>>();
|
||||||
let mut data = vec![(len & 0x7f) as u8 | 0x80, (len >> 7) as u8];
|
let mut data = vec![(len & 0x7f) as u8 | 0x80, (len >> 7) as u8];
|
||||||
data.extend(frame.clone().into_iter());
|
data.extend(frame.clone().into_iter());
|
||||||
let framed = LengthDelimited::new(Cursor::new(data));
|
let mut framed = LengthDelimited::new(Cursor::new(data));
|
||||||
let recved = framed
|
let recved = futures::executor::block_on(async move {
|
||||||
.into_future()
|
framed.next().await
|
||||||
.map(|(m, _)| m)
|
}).unwrap();
|
||||||
.map_err(|_| ())
|
|
||||||
.wait()
|
|
||||||
.unwrap();
|
|
||||||
assert_eq!(recved.unwrap(), frame);
|
assert_eq!(recved.unwrap(), frame);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -392,12 +408,10 @@ mod tests {
|
|||||||
fn packet_len_too_long() {
|
fn packet_len_too_long() {
|
||||||
let mut data = vec![0x81, 0x81, 0x1];
|
let mut data = vec![0x81, 0x81, 0x1];
|
||||||
data.extend((0..16513).map(|_| 0));
|
data.extend((0..16513).map(|_| 0));
|
||||||
let framed = LengthDelimited::new(Cursor::new(data));
|
let mut framed = LengthDelimited::new(Cursor::new(data));
|
||||||
let recved = framed
|
let recved = futures::executor::block_on(async move {
|
||||||
.into_future()
|
framed.next().await.unwrap()
|
||||||
.map(|(m, _)| m)
|
});
|
||||||
.map_err(|(err, _)| err)
|
|
||||||
.wait();
|
|
||||||
|
|
||||||
if let Err(io_err) = recved {
|
if let Err(io_err) = recved {
|
||||||
assert_eq!(io_err.kind(), ErrorKind::InvalidData)
|
assert_eq!(io_err.kind(), ErrorKind::InvalidData)
|
||||||
@ -410,7 +424,7 @@ mod tests {
|
|||||||
fn empty_frames() {
|
fn empty_frames() {
|
||||||
let data = vec![0, 0, 6, 9, 8, 7, 6, 5, 4, 0, 3, 9, 8, 7];
|
let data = vec![0, 0, 6, 9, 8, 7, 6, 5, 4, 0, 3, 9, 8, 7];
|
||||||
let framed = LengthDelimited::new(Cursor::new(data));
|
let framed = LengthDelimited::new(Cursor::new(data));
|
||||||
let recved = framed.collect().wait().unwrap();
|
let recved = futures::executor::block_on(framed.try_collect::<Vec<_>>()).unwrap();
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
recved,
|
recved,
|
||||||
vec![
|
vec![
|
||||||
@ -427,7 +441,7 @@ mod tests {
|
|||||||
fn unexpected_eof_in_len() {
|
fn unexpected_eof_in_len() {
|
||||||
let data = vec![0x89];
|
let data = vec![0x89];
|
||||||
let framed = LengthDelimited::new(Cursor::new(data));
|
let framed = LengthDelimited::new(Cursor::new(data));
|
||||||
let recved = framed.collect().wait();
|
let recved = futures::executor::block_on(framed.try_collect::<Vec<_>>());
|
||||||
if let Err(io_err) = recved {
|
if let Err(io_err) = recved {
|
||||||
assert_eq!(io_err.kind(), ErrorKind::UnexpectedEof)
|
assert_eq!(io_err.kind(), ErrorKind::UnexpectedEof)
|
||||||
} else {
|
} else {
|
||||||
@ -439,7 +453,7 @@ mod tests {
|
|||||||
fn unexpected_eof_in_data() {
|
fn unexpected_eof_in_data() {
|
||||||
let data = vec![5];
|
let data = vec![5];
|
||||||
let framed = LengthDelimited::new(Cursor::new(data));
|
let framed = LengthDelimited::new(Cursor::new(data));
|
||||||
let recved = framed.collect().wait();
|
let recved = futures::executor::block_on(framed.try_collect::<Vec<_>>());
|
||||||
if let Err(io_err) = recved {
|
if let Err(io_err) = recved {
|
||||||
assert_eq!(io_err.kind(), ErrorKind::UnexpectedEof)
|
assert_eq!(io_err.kind(), ErrorKind::UnexpectedEof)
|
||||||
} else {
|
} else {
|
||||||
@ -451,12 +465,54 @@ mod tests {
|
|||||||
fn unexpected_eof_in_data2() {
|
fn unexpected_eof_in_data2() {
|
||||||
let data = vec![5, 9, 8, 7];
|
let data = vec![5, 9, 8, 7];
|
||||||
let framed = LengthDelimited::new(Cursor::new(data));
|
let framed = LengthDelimited::new(Cursor::new(data));
|
||||||
let recved = framed.collect().wait();
|
let recved = futures::executor::block_on(framed.try_collect::<Vec<_>>());
|
||||||
if let Err(io_err) = recved {
|
if let Err(io_err) = recved {
|
||||||
assert_eq!(io_err.kind(), ErrorKind::UnexpectedEof)
|
assert_eq!(io_err.kind(), ErrorKind::UnexpectedEof)
|
||||||
} else {
|
} else {
|
||||||
panic!()
|
panic!()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn writing_reading() {
|
||||||
|
fn prop(frames: Vec<Vec<u8>>) -> TestResult {
|
||||||
|
async_std::task::block_on(async move {
|
||||||
|
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
|
||||||
|
let listener_addr = listener.local_addr().unwrap();
|
||||||
|
|
||||||
|
let expected_frames = frames.clone();
|
||||||
|
let server = async_std::task::spawn(async move {
|
||||||
|
let socket = listener.accept().await.unwrap().0;
|
||||||
|
let mut connec = rw_stream_sink::RwStreamSink::new(LengthDelimited::new(socket));
|
||||||
|
|
||||||
|
let mut buf = vec![0u8; 0];
|
||||||
|
for expected in expected_frames {
|
||||||
|
if expected.is_empty() {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
if buf.len() < expected.len() {
|
||||||
|
buf.resize(expected.len(), 0);
|
||||||
|
}
|
||||||
|
let n = connec.read(&mut buf).await.unwrap();
|
||||||
|
assert_eq!(&buf[..n], &expected[..]);
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
let client = async_std::task::spawn(async move {
|
||||||
|
let socket = TcpStream::connect(&listener_addr).await.unwrap();
|
||||||
|
let mut connec = LengthDelimited::new(socket);
|
||||||
|
for frame in frames {
|
||||||
|
connec.send(From::from(frame)).await.unwrap();
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
server.await;
|
||||||
|
client.await;
|
||||||
|
});
|
||||||
|
|
||||||
|
TestResult::passed()
|
||||||
|
}
|
||||||
|
|
||||||
|
quickcheck(prop as fn(_) -> _)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
@ -77,26 +77,19 @@
|
|||||||
//!
|
//!
|
||||||
//! ```no_run
|
//! ```no_run
|
||||||
//! # fn main() {
|
//! # fn main() {
|
||||||
//! use bytes::Bytes;
|
//! use async_std::net::TcpStream;
|
||||||
//! use multistream_select::{dialer_select_proto, Version};
|
//! use multistream_select::{dialer_select_proto, Version};
|
||||||
//! use futures::{Future, Sink, Stream};
|
//! use futures::prelude::*;
|
||||||
//! use tokio_tcp::TcpStream;
|
|
||||||
//! use tokio::runtime::current_thread::Runtime;
|
|
||||||
//!
|
//!
|
||||||
//! #[derive(Debug, Copy, Clone)]
|
//! async_std::task::block_on(async move {
|
||||||
//! enum MyProto { Echo, Hello }
|
//! let socket = TcpStream::connect("127.0.0.1:10333").await.unwrap();
|
||||||
//!
|
//!
|
||||||
//! let client = TcpStream::connect(&"127.0.0.1:10333".parse().unwrap())
|
//! let protos = vec![b"/echo/1.0.0", b"/echo/2.5.0"];
|
||||||
//! .from_err()
|
//! let (protocol, _io) = dialer_select_proto(socket, protos, Version::V1).await.unwrap();
|
||||||
//! .and_then(move |io| {
|
|
||||||
//! let protos = vec![b"/echo/1.0.0", b"/echo/2.5.0"];
|
|
||||||
//! dialer_select_proto(io, protos, Version::V1)
|
|
||||||
//! })
|
|
||||||
//! .map(|(protocol, _io)| protocol);
|
|
||||||
//!
|
//!
|
||||||
//! let mut rt = Runtime::new().unwrap();
|
//! println!("Negotiated protocol: {:?}", protocol);
|
||||||
//! let protocol = rt.block_on(client).expect("failed to find a protocol");
|
//! // You can now use `_io` to communicate with the remote.
|
||||||
//! println!("Negotiated protocol: {:?}", protocol);
|
//! });
|
||||||
//! # }
|
//! # }
|
||||||
//! ```
|
//! ```
|
||||||
//!
|
//!
|
||||||
|
@ -21,13 +21,12 @@
|
|||||||
//! Protocol negotiation strategies for the peer acting as the listener
|
//! Protocol negotiation strategies for the peer acting as the listener
|
||||||
//! in a multistream-select protocol negotiation.
|
//! in a multistream-select protocol negotiation.
|
||||||
|
|
||||||
use futures::prelude::*;
|
|
||||||
use crate::protocol::{Protocol, ProtocolError, MessageIO, Message, Version};
|
|
||||||
use log::{debug, warn};
|
|
||||||
use smallvec::SmallVec;
|
|
||||||
use std::{io, iter::FromIterator, mem, convert::TryFrom};
|
|
||||||
use tokio_io::{AsyncRead, AsyncWrite};
|
|
||||||
use crate::{Negotiated, NegotiationError};
|
use crate::{Negotiated, NegotiationError};
|
||||||
|
use crate::protocol::{Protocol, ProtocolError, MessageIO, Message, Version};
|
||||||
|
|
||||||
|
use futures::prelude::*;
|
||||||
|
use smallvec::SmallVec;
|
||||||
|
use std::{convert::TryFrom as _, io, iter::FromIterator, mem, pin::Pin, task::{Context, Poll}};
|
||||||
|
|
||||||
/// Returns a `Future` that negotiates a protocol on the given I/O stream
|
/// Returns a `Future` that negotiates a protocol on the given I/O stream
|
||||||
/// for a peer acting as the _listener_ (or _responder_).
|
/// for a peer acting as the _listener_ (or _responder_).
|
||||||
@ -49,7 +48,7 @@ where
|
|||||||
match Protocol::try_from(n.as_ref()) {
|
match Protocol::try_from(n.as_ref()) {
|
||||||
Ok(p) => Some((n, p)),
|
Ok(p) => Some((n, p)),
|
||||||
Err(e) => {
|
Err(e) => {
|
||||||
warn!("Listener: Ignoring invalid protocol: {} due to {}",
|
log::warn!("Listener: Ignoring invalid protocol: {} due to {}",
|
||||||
String::from_utf8_lossy(n.as_ref()), e);
|
String::from_utf8_lossy(n.as_ref()), e);
|
||||||
None
|
None
|
||||||
}
|
}
|
||||||
@ -64,6 +63,7 @@ where
|
|||||||
|
|
||||||
/// The `Future` returned by [`listener_select_proto`] that performs a
|
/// The `Future` returned by [`listener_select_proto`] that performs a
|
||||||
/// multistream-select protocol negotiation on an underlying I/O stream.
|
/// multistream-select protocol negotiation on an underlying I/O stream.
|
||||||
|
#[pin_project::pin_project]
|
||||||
pub struct ListenerSelectFuture<R, N>
|
pub struct ListenerSelectFuture<R, N>
|
||||||
where
|
where
|
||||||
R: AsyncRead + AsyncWrite,
|
R: AsyncRead + AsyncWrite,
|
||||||
@ -94,64 +94,80 @@ where
|
|||||||
|
|
||||||
impl<R, N> Future for ListenerSelectFuture<R, N>
|
impl<R, N> Future for ListenerSelectFuture<R, N>
|
||||||
where
|
where
|
||||||
R: AsyncRead + AsyncWrite,
|
// The Unpin bound here is required because we produce a `Negotiated<R>` as the output.
|
||||||
|
// It also makes the implementation considerably easier to write.
|
||||||
|
R: AsyncRead + AsyncWrite + Unpin,
|
||||||
N: AsRef<[u8]> + Clone
|
N: AsRef<[u8]> + Clone
|
||||||
{
|
{
|
||||||
type Item = (N, Negotiated<R>);
|
type Output = Result<(N, Negotiated<R>), NegotiationError>;
|
||||||
type Error = NegotiationError;
|
|
||||||
|
fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
|
||||||
|
let this = self.project();
|
||||||
|
|
||||||
fn poll(&mut self) -> Poll<Self::Item, Self::Error> {
|
|
||||||
loop {
|
loop {
|
||||||
match mem::replace(&mut self.state, State::Done) {
|
match mem::replace(this.state, State::Done) {
|
||||||
State::RecvHeader { mut io } => {
|
State::RecvHeader { mut io } => {
|
||||||
match io.poll()? {
|
match io.poll_next_unpin(cx) {
|
||||||
Async::Ready(Some(Message::Header(version))) => {
|
Poll::Ready(Some(Ok(Message::Header(version)))) => {
|
||||||
self.state = State::SendHeader { io, version }
|
*this.state = State::SendHeader { io, version }
|
||||||
}
|
}
|
||||||
Async::Ready(Some(_)) => {
|
Poll::Ready(Some(Ok(_))) => {
|
||||||
return Err(ProtocolError::InvalidMessage.into())
|
return Poll::Ready(Err(ProtocolError::InvalidMessage.into()))
|
||||||
}
|
},
|
||||||
Async::Ready(None) =>
|
Poll::Ready(Some(Err(err))) => return Poll::Ready(Err(From::from(err))),
|
||||||
return Err(NegotiationError::from(
|
Poll::Ready(None) =>
|
||||||
|
return Poll::Ready(Err(NegotiationError::from(
|
||||||
ProtocolError::IoError(
|
ProtocolError::IoError(
|
||||||
io::ErrorKind::UnexpectedEof.into()))),
|
io::ErrorKind::UnexpectedEof.into())))),
|
||||||
Async::NotReady => {
|
Poll::Pending => {
|
||||||
self.state = State::RecvHeader { io };
|
*this.state = State::RecvHeader { io };
|
||||||
return Ok(Async::NotReady)
|
return Poll::Pending
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
State::SendHeader { mut io, version } => {
|
State::SendHeader { mut io, version } => {
|
||||||
if io.start_send(Message::Header(version))?.is_not_ready() {
|
match Pin::new(&mut io).poll_ready(cx) {
|
||||||
return Ok(Async::NotReady)
|
Poll::Pending => {
|
||||||
|
*this.state = State::SendHeader { io, version };
|
||||||
|
return Poll::Pending
|
||||||
|
},
|
||||||
|
Poll::Ready(Ok(())) => {},
|
||||||
|
Poll::Ready(Err(err)) => return Poll::Ready(Err(From::from(err))),
|
||||||
}
|
}
|
||||||
self.state = match version {
|
|
||||||
|
if let Err(err) = Pin::new(&mut io).start_send(Message::Header(version)) {
|
||||||
|
return Poll::Ready(Err(From::from(err)));
|
||||||
|
}
|
||||||
|
|
||||||
|
*this.state = match version {
|
||||||
Version::V1 => State::Flush { io },
|
Version::V1 => State::Flush { io },
|
||||||
Version::V1Lazy => State::RecvMessage { io },
|
Version::V1Lazy => State::RecvMessage { io },
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
State::RecvMessage { mut io } => {
|
State::RecvMessage { mut io } => {
|
||||||
let msg = match io.poll() {
|
let msg = match Pin::new(&mut io).poll_next(cx) {
|
||||||
Ok(Async::Ready(Some(msg))) => msg,
|
Poll::Ready(Some(Ok(msg))) => msg,
|
||||||
Ok(Async::Ready(None)) =>
|
Poll::Ready(None) =>
|
||||||
return Err(NegotiationError::from(
|
return Poll::Ready(Err(NegotiationError::from(
|
||||||
ProtocolError::IoError(
|
ProtocolError::IoError(
|
||||||
io::ErrorKind::UnexpectedEof.into()))),
|
io::ErrorKind::UnexpectedEof.into())))),
|
||||||
Ok(Async::NotReady) => {
|
Poll::Pending => {
|
||||||
self.state = State::RecvMessage { io };
|
*this.state = State::RecvMessage { io };
|
||||||
return Ok(Async::NotReady)
|
return Poll::Pending;
|
||||||
}
|
}
|
||||||
Err(e) => return Err(e.into())
|
Poll::Ready(Some(Err(err))) => return Poll::Ready(Err(From::from(err))),
|
||||||
};
|
};
|
||||||
|
|
||||||
match msg {
|
match msg {
|
||||||
Message::ListProtocols => {
|
Message::ListProtocols => {
|
||||||
let supported = self.protocols.iter().map(|(_,p)| p).cloned().collect();
|
let supported = this.protocols.iter().map(|(_,p)| p).cloned().collect();
|
||||||
let message = Message::Protocols(supported);
|
let message = Message::Protocols(supported);
|
||||||
self.state = State::SendMessage { io, message, protocol: None }
|
*this.state = State::SendMessage { io, message, protocol: None }
|
||||||
}
|
}
|
||||||
Message::Protocol(p) => {
|
Message::Protocol(p) => {
|
||||||
let protocol = self.protocols.iter().find_map(|(name, proto)| {
|
let protocol = this.protocols.iter().find_map(|(name, proto)| {
|
||||||
if &p == proto {
|
if &p == proto {
|
||||||
Some(name.clone())
|
Some(name.clone())
|
||||||
} else {
|
} else {
|
||||||
@ -160,45 +176,60 @@ where
|
|||||||
});
|
});
|
||||||
|
|
||||||
let message = if protocol.is_some() {
|
let message = if protocol.is_some() {
|
||||||
debug!("Listener: confirming protocol: {}", p);
|
log::debug!("Listener: confirming protocol: {}", p);
|
||||||
Message::Protocol(p.clone())
|
Message::Protocol(p.clone())
|
||||||
} else {
|
} else {
|
||||||
debug!("Listener: rejecting protocol: {}",
|
log::debug!("Listener: rejecting protocol: {}",
|
||||||
String::from_utf8_lossy(p.as_ref()));
|
String::from_utf8_lossy(p.as_ref()));
|
||||||
Message::NotAvailable
|
Message::NotAvailable
|
||||||
};
|
};
|
||||||
|
|
||||||
self.state = State::SendMessage { io, message, protocol };
|
*this.state = State::SendMessage { io, message, protocol };
|
||||||
}
|
}
|
||||||
_ => return Err(ProtocolError::InvalidMessage.into())
|
_ => return Poll::Ready(Err(ProtocolError::InvalidMessage.into()))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
State::SendMessage { mut io, message, protocol } => {
|
State::SendMessage { mut io, message, protocol } => {
|
||||||
if let AsyncSink::NotReady(message) = io.start_send(message)? {
|
match Pin::new(&mut io).poll_ready(cx) {
|
||||||
self.state = State::SendMessage { io, message, protocol };
|
Poll::Pending => {
|
||||||
return Ok(Async::NotReady)
|
*this.state = State::SendMessage { io, message, protocol };
|
||||||
};
|
return Poll::Pending
|
||||||
|
},
|
||||||
|
Poll::Ready(Ok(())) => {},
|
||||||
|
Poll::Ready(Err(err)) => return Poll::Ready(Err(From::from(err))),
|
||||||
|
}
|
||||||
|
|
||||||
|
if let Err(err) = Pin::new(&mut io).start_send(message) {
|
||||||
|
return Poll::Ready(Err(From::from(err)));
|
||||||
|
}
|
||||||
|
|
||||||
// If a protocol has been selected, finish negotiation.
|
// If a protocol has been selected, finish negotiation.
|
||||||
// Otherwise flush the sink and expect to receive another
|
// Otherwise flush the sink and expect to receive another
|
||||||
// message.
|
// message.
|
||||||
self.state = match protocol {
|
*this.state = match protocol {
|
||||||
Some(protocol) => {
|
Some(protocol) => {
|
||||||
debug!("Listener: sent confirmed protocol: {}",
|
log::debug!("Listener: sent confirmed protocol: {}",
|
||||||
String::from_utf8_lossy(protocol.as_ref()));
|
String::from_utf8_lossy(protocol.as_ref()));
|
||||||
let (io, remaining) = io.into_inner();
|
let (io, remaining) = io.into_inner();
|
||||||
let io = Negotiated::completed(io, remaining);
|
let io = Negotiated::completed(io, remaining);
|
||||||
return Ok(Async::Ready((protocol, io)))
|
return Poll::Ready(Ok((protocol, io)));
|
||||||
}
|
}
|
||||||
None => State::Flush { io }
|
None => State::Flush { io }
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
State::Flush { mut io } => {
|
State::Flush { mut io } => {
|
||||||
if io.poll_complete()?.is_not_ready() {
|
match Pin::new(&mut io).poll_flush(cx) {
|
||||||
self.state = State::Flush { io };
|
Poll::Pending => {
|
||||||
return Ok(Async::NotReady)
|
*this.state = State::Flush { io };
|
||||||
|
return Poll::Pending
|
||||||
|
},
|
||||||
|
Poll::Ready(Ok(())) => *this.state = State::RecvMessage { io },
|
||||||
|
Poll::Ready(Err(err)) => return Poll::Ready(Err(From::from(err))),
|
||||||
}
|
}
|
||||||
self.state = State::RecvMessage { io }
|
|
||||||
}
|
}
|
||||||
|
|
||||||
State::Done => panic!("State::poll called after completion")
|
State::Done => panic!("State::poll called after completion")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -18,12 +18,12 @@
|
|||||||
// FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
|
// FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
|
||||||
// DEALINGS IN THE SOFTWARE.
|
// DEALINGS IN THE SOFTWARE.
|
||||||
|
|
||||||
use bytes::{BytesMut, Buf};
|
|
||||||
use crate::protocol::{Protocol, MessageReader, Message, Version, ProtocolError};
|
use crate::protocol::{Protocol, MessageReader, Message, Version, ProtocolError};
|
||||||
use futures::{prelude::*, Async, try_ready};
|
|
||||||
use log::debug;
|
use bytes::{BytesMut, Buf};
|
||||||
use tokio_io::{AsyncRead, AsyncWrite};
|
use futures::{prelude::*, io::{IoSlice, IoSliceMut}, ready};
|
||||||
use std::{mem, io, fmt, error::Error};
|
use pin_project::{pin_project, project};
|
||||||
|
use std::{error::Error, fmt, io, mem, pin::Pin, task::{Context, Poll}};
|
||||||
|
|
||||||
/// An I/O stream that has settled on an (application-layer) protocol to use.
|
/// An I/O stream that has settled on an (application-layer) protocol to use.
|
||||||
///
|
///
|
||||||
@ -36,28 +36,40 @@ use std::{mem, io, fmt, error::Error};
|
|||||||
///
|
///
|
||||||
/// Reading from a `Negotiated` I/O stream that still has pending negotiation
|
/// Reading from a `Negotiated` I/O stream that still has pending negotiation
|
||||||
/// protocol data to send implicitly triggers flushing of all yet unsent data.
|
/// protocol data to send implicitly triggers flushing of all yet unsent data.
|
||||||
|
#[pin_project]
|
||||||
#[derive(Debug)]
|
#[derive(Debug)]
|
||||||
pub struct Negotiated<TInner> {
|
pub struct Negotiated<TInner> {
|
||||||
|
#[pin]
|
||||||
state: State<TInner>
|
state: State<TInner>
|
||||||
}
|
}
|
||||||
|
|
||||||
/// A `Future` that waits on the completion of protocol negotiation.
|
/// A `Future` that waits on the completion of protocol negotiation.
|
||||||
#[derive(Debug)]
|
#[derive(Debug)]
|
||||||
pub struct NegotiatedComplete<TInner> {
|
pub struct NegotiatedComplete<TInner> {
|
||||||
inner: Option<Negotiated<TInner>>
|
inner: Option<Negotiated<TInner>>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<TInner: AsyncRead + AsyncWrite> Future for NegotiatedComplete<TInner> {
|
impl<TInner> Future for NegotiatedComplete<TInner>
|
||||||
type Item = Negotiated<TInner>;
|
where
|
||||||
type Error = NegotiationError;
|
// `Unpin` is required not because of implementation details but because we produce the
|
||||||
|
// `Negotiated` as the output of the future.
|
||||||
|
TInner: AsyncRead + AsyncWrite + Unpin,
|
||||||
|
{
|
||||||
|
type Output = Result<Negotiated<TInner>, NegotiationError>;
|
||||||
|
|
||||||
fn poll(&mut self) -> Poll<Self::Item, Self::Error> {
|
fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
|
||||||
let mut io = self.inner.take().expect("NegotiatedFuture called after completion.");
|
let mut io = self.inner.take().expect("NegotiatedFuture called after completion.");
|
||||||
if io.poll()?.is_not_ready() {
|
match Negotiated::poll(Pin::new(&mut io), cx) {
|
||||||
self.inner = Some(io);
|
Poll::Pending => {
|
||||||
return Ok(Async::NotReady)
|
self.inner = Some(io);
|
||||||
|
return Poll::Pending
|
||||||
|
},
|
||||||
|
Poll::Ready(Ok(())) => Poll::Ready(Ok(io)),
|
||||||
|
Poll::Ready(Err(err)) => {
|
||||||
|
self.inner = Some(io);
|
||||||
|
return Poll::Ready(Err(err));
|
||||||
|
}
|
||||||
}
|
}
|
||||||
return Ok(Async::Ready(io))
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -75,66 +87,67 @@ impl<TInner> Negotiated<TInner> {
|
|||||||
}
|
}
|
||||||
|
|
||||||
/// Polls the `Negotiated` for completion.
|
/// Polls the `Negotiated` for completion.
|
||||||
fn poll(&mut self) -> Poll<(), NegotiationError>
|
#[project]
|
||||||
|
fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), NegotiationError>>
|
||||||
where
|
where
|
||||||
TInner: AsyncRead + AsyncWrite
|
TInner: AsyncRead + AsyncWrite + Unpin
|
||||||
{
|
{
|
||||||
// Flush any pending negotiation data.
|
// Flush any pending negotiation data.
|
||||||
match self.poll_flush() {
|
match self.as_mut().poll_flush(cx) {
|
||||||
Ok(Async::Ready(())) => {},
|
Poll::Ready(Ok(())) => {},
|
||||||
Ok(Async::NotReady) => return Ok(Async::NotReady),
|
Poll::Pending => return Poll::Pending,
|
||||||
Err(e) => {
|
Poll::Ready(Err(e)) => {
|
||||||
// If the remote closed the stream, it is important to still
|
// If the remote closed the stream, it is important to still
|
||||||
// continue reading the data that was sent, if any.
|
// continue reading the data that was sent, if any.
|
||||||
if e.kind() != io::ErrorKind::WriteZero {
|
if e.kind() != io::ErrorKind::WriteZero {
|
||||||
return Err(e.into())
|
return Poll::Ready(Err(e.into()))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if let State::Completed { remaining, .. } = &mut self.state {
|
let mut this = self.project();
|
||||||
let _ = remaining.split_to(remaining.len()); // Drop remaining data flushed above.
|
|
||||||
return Ok(Async::Ready(()))
|
#[project]
|
||||||
|
match this.state.as_mut().project() {
|
||||||
|
State::Completed { remaining, .. } => {
|
||||||
|
debug_assert!(remaining.is_empty());
|
||||||
|
return Poll::Ready(Ok(()))
|
||||||
|
}
|
||||||
|
_ => {}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Read outstanding protocol negotiation messages.
|
// Read outstanding protocol negotiation messages.
|
||||||
loop {
|
loop {
|
||||||
match mem::replace(&mut self.state, State::Invalid) {
|
match mem::replace(&mut *this.state, State::Invalid) {
|
||||||
State::Expecting { mut io, protocol, version } => {
|
State::Expecting { mut io, protocol, version } => {
|
||||||
let msg = match io.poll() {
|
let msg = match Pin::new(&mut io).poll_next(cx)? {
|
||||||
Ok(Async::Ready(Some(msg))) => msg,
|
Poll::Ready(Some(msg)) => msg,
|
||||||
Ok(Async::NotReady) => {
|
Poll::Pending => {
|
||||||
self.state = State::Expecting { io, protocol, version };
|
*this.state = State::Expecting { io, protocol, version };
|
||||||
return Ok(Async::NotReady)
|
return Poll::Pending
|
||||||
}
|
},
|
||||||
Ok(Async::Ready(None)) => {
|
Poll::Ready(None) => {
|
||||||
self.state = State::Expecting { io, protocol, version };
|
return Poll::Ready(Err(ProtocolError::IoError(
|
||||||
return Err(ProtocolError::IoError(
|
io::ErrorKind::UnexpectedEof.into()).into()));
|
||||||
io::ErrorKind::UnexpectedEof.into()).into())
|
|
||||||
}
|
|
||||||
Err(err) => {
|
|
||||||
self.state = State::Expecting { io, protocol, version };
|
|
||||||
return Err(err.into())
|
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
if let Message::Header(v) = &msg {
|
if let Message::Header(v) = &msg {
|
||||||
if v == &version {
|
if *v == version {
|
||||||
self.state = State::Expecting { io, protocol, version };
|
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if let Message::Protocol(p) = &msg {
|
if let Message::Protocol(p) = &msg {
|
||||||
if p.as_ref() == protocol.as_ref() {
|
if p.as_ref() == protocol.as_ref() {
|
||||||
debug!("Negotiated: Received confirmation for protocol: {}", p);
|
log::debug!("Negotiated: Received confirmation for protocol: {}", p);
|
||||||
let (io, remaining) = io.into_inner();
|
let (io, remaining) = io.into_inner();
|
||||||
self.state = State::Completed { io, remaining };
|
*this.state = State::Completed { io, remaining };
|
||||||
return Ok(Async::Ready(()))
|
return Poll::Ready(Ok(()));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return Err(NegotiationError::Failed)
|
return Poll::Ready(Err(NegotiationError::Failed));
|
||||||
}
|
}
|
||||||
|
|
||||||
_ => panic!("Negotiated: Invalid state")
|
_ => panic!("Negotiated: Invalid state")
|
||||||
@ -142,7 +155,7 @@ impl<TInner> Negotiated<TInner> {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Returns a `NegotiatedComplete` future that waits for protocol
|
/// Returns a [`NegotiatedComplete`] future that waits for protocol
|
||||||
/// negotiation to complete.
|
/// negotiation to complete.
|
||||||
pub fn complete(self) -> NegotiatedComplete<TInner> {
|
pub fn complete(self) -> NegotiatedComplete<TInner> {
|
||||||
NegotiatedComplete { inner: Some(self) }
|
NegotiatedComplete { inner: Some(self) }
|
||||||
@ -150,12 +163,14 @@ impl<TInner> Negotiated<TInner> {
|
|||||||
}
|
}
|
||||||
|
|
||||||
/// The states of a `Negotiated` I/O stream.
|
/// The states of a `Negotiated` I/O stream.
|
||||||
|
#[pin_project]
|
||||||
#[derive(Debug)]
|
#[derive(Debug)]
|
||||||
enum State<R> {
|
enum State<R> {
|
||||||
/// In this state, a `Negotiated` is still expecting to
|
/// In this state, a `Negotiated` is still expecting to
|
||||||
/// receive confirmation of the protocol it as settled on.
|
/// receive confirmation of the protocol it as settled on.
|
||||||
Expecting {
|
Expecting {
|
||||||
/// The underlying I/O stream.
|
/// The underlying I/O stream.
|
||||||
|
#[pin]
|
||||||
io: MessageReader<R>,
|
io: MessageReader<R>,
|
||||||
/// The expected protocol (i.e. name and version).
|
/// The expected protocol (i.e. name and version).
|
||||||
protocol: Protocol,
|
protocol: Protocol,
|
||||||
@ -167,113 +182,157 @@ enum State<R> {
|
|||||||
/// only be pending the sending of the final acknowledgement,
|
/// only be pending the sending of the final acknowledgement,
|
||||||
/// which is prepended to / combined with the next write for
|
/// which is prepended to / combined with the next write for
|
||||||
/// efficiency.
|
/// efficiency.
|
||||||
Completed { io: R, remaining: BytesMut },
|
Completed { #[pin] io: R, remaining: BytesMut },
|
||||||
|
|
||||||
/// Temporary state while moving the `io` resource from
|
/// Temporary state while moving the `io` resource from
|
||||||
/// `Expecting` to `Completed`.
|
/// `Expecting` to `Completed`.
|
||||||
Invalid,
|
Invalid,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<R> io::Read for Negotiated<R>
|
impl<TInner> AsyncRead for Negotiated<TInner>
|
||||||
where
|
where
|
||||||
R: AsyncRead + AsyncWrite
|
TInner: AsyncRead + AsyncWrite + Unpin
|
||||||
{
|
{
|
||||||
fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
|
#[project]
|
||||||
|
fn poll_read(mut self: Pin<&mut Self>, cx: &mut Context, buf: &mut [u8])
|
||||||
|
-> Poll<Result<usize, io::Error>>
|
||||||
|
{
|
||||||
loop {
|
loop {
|
||||||
if let State::Completed { io, remaining } = &mut self.state {
|
#[project]
|
||||||
// If protocol negotiation is complete and there is no
|
match self.as_mut().project().state.project() {
|
||||||
// remaining data to be flushed, commence with reading.
|
State::Completed { io, remaining } => {
|
||||||
if remaining.is_empty() {
|
// If protocol negotiation is complete and there is no
|
||||||
return io.read(buf)
|
// remaining data to be flushed, commence with reading.
|
||||||
}
|
if remaining.is_empty() {
|
||||||
|
return io.poll_read(cx, buf)
|
||||||
|
}
|
||||||
|
},
|
||||||
|
_ => {}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Poll the `Negotiated`, driving protocol negotiation to completion,
|
// Poll the `Negotiated`, driving protocol negotiation to completion,
|
||||||
// including flushing of any remaining data.
|
// including flushing of any remaining data.
|
||||||
let result = self.poll();
|
match self.as_mut().poll(cx) {
|
||||||
|
Poll::Ready(Ok(())) => {},
|
||||||
// There is still remaining data to be sent before data relating
|
Poll::Pending => return Poll::Pending,
|
||||||
// to the negotiated protocol can be read.
|
Poll::Ready(Err(err)) => return Poll::Ready(Err(From::from(err))),
|
||||||
if let Ok(Async::NotReady) = result {
|
|
||||||
return Err(io::ErrorKind::WouldBlock.into())
|
|
||||||
}
|
|
||||||
|
|
||||||
if let Err(err) = result {
|
|
||||||
return Err(err.into())
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
impl<TInner> AsyncRead for Negotiated<TInner>
|
// TODO: implement once method is stabilized in the futures crate
|
||||||
where
|
/*unsafe fn initializer(&self) -> Initializer {
|
||||||
TInner: AsyncRead + AsyncWrite
|
|
||||||
{
|
|
||||||
unsafe fn prepare_uninitialized_buffer(&self, buf: &mut [u8]) -> bool {
|
|
||||||
match &self.state {
|
match &self.state {
|
||||||
State::Completed { io, .. } =>
|
State::Completed { io, .. } => io.initializer(),
|
||||||
io.prepare_uninitialized_buffer(buf),
|
State::Expecting { io, .. } => io.inner_ref().initializer(),
|
||||||
State::Expecting { io, .. } =>
|
State::Invalid => panic!("Negotiated: Invalid state"),
|
||||||
io.inner_ref().prepare_uninitialized_buffer(buf),
|
|
||||||
State::Invalid => panic!("Negotiated: Invalid state")
|
|
||||||
}
|
}
|
||||||
}
|
}*/
|
||||||
}
|
|
||||||
|
|
||||||
impl<TInner> io::Write for Negotiated<TInner>
|
#[project]
|
||||||
where
|
fn poll_read_vectored(mut self: Pin<&mut Self>, cx: &mut Context, bufs: &mut [IoSliceMut])
|
||||||
TInner: AsyncWrite
|
-> Poll<Result<usize, io::Error>>
|
||||||
{
|
{
|
||||||
fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
|
loop {
|
||||||
match &mut self.state {
|
#[project]
|
||||||
State::Completed { io, ref mut remaining } => {
|
match self.as_mut().project().state.project() {
|
||||||
while !remaining.is_empty() {
|
State::Completed { io, remaining } => {
|
||||||
let n = io.write(&remaining)?;
|
// If protocol negotiation is complete and there is no
|
||||||
if n == 0 {
|
// remaining data to be flushed, commence with reading.
|
||||||
return Err(io::ErrorKind::WriteZero.into())
|
if remaining.is_empty() {
|
||||||
|
return io.poll_read_vectored(cx, bufs)
|
||||||
}
|
}
|
||||||
remaining.advance(n);
|
},
|
||||||
}
|
_ => {}
|
||||||
io.write(buf)
|
}
|
||||||
},
|
|
||||||
State::Expecting { io, .. } => io.write(buf),
|
|
||||||
State::Invalid => panic!("Negotiated: Invalid state")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
fn flush(&mut self) -> io::Result<()> {
|
// Poll the `Negotiated`, driving protocol negotiation to completion,
|
||||||
match &mut self.state {
|
// including flushing of any remaining data.
|
||||||
State::Completed { io, ref mut remaining } => {
|
match self.as_mut().poll(cx) {
|
||||||
while !remaining.is_empty() {
|
Poll::Ready(Ok(())) => {},
|
||||||
let n = io.write(remaining)?;
|
Poll::Pending => return Poll::Pending,
|
||||||
if n == 0 {
|
Poll::Ready(Err(err)) => return Poll::Ready(Err(From::from(err))),
|
||||||
return Err(io::Error::new(
|
}
|
||||||
io::ErrorKind::WriteZero,
|
|
||||||
"Failed to write remaining buffer."))
|
|
||||||
}
|
|
||||||
remaining.advance(n);
|
|
||||||
}
|
|
||||||
io.flush()
|
|
||||||
},
|
|
||||||
State::Expecting { io, .. } => io.flush(),
|
|
||||||
State::Invalid => panic!("Negotiated: Invalid state")
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<TInner> AsyncWrite for Negotiated<TInner>
|
impl<TInner> AsyncWrite for Negotiated<TInner>
|
||||||
where
|
where
|
||||||
TInner: AsyncWrite + AsyncRead
|
TInner: AsyncWrite + AsyncRead + Unpin
|
||||||
{
|
{
|
||||||
fn shutdown(&mut self) -> Poll<(), io::Error> {
|
#[project]
|
||||||
|
fn poll_write(self: Pin<&mut Self>, cx: &mut Context, buf: &[u8]) -> Poll<Result<usize, io::Error>> {
|
||||||
|
#[project]
|
||||||
|
match self.project().state.project() {
|
||||||
|
State::Completed { mut io, remaining } => {
|
||||||
|
while !remaining.is_empty() {
|
||||||
|
let n = ready!(io.as_mut().poll_write(cx, &remaining)?);
|
||||||
|
if n == 0 {
|
||||||
|
return Poll::Ready(Err(io::ErrorKind::WriteZero.into()))
|
||||||
|
}
|
||||||
|
remaining.advance(n);
|
||||||
|
}
|
||||||
|
io.poll_write(cx, buf)
|
||||||
|
},
|
||||||
|
State::Expecting { io, .. } => io.poll_write(cx, buf),
|
||||||
|
State::Invalid => panic!("Negotiated: Invalid state"),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[project]
|
||||||
|
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), io::Error>> {
|
||||||
|
#[project]
|
||||||
|
match self.project().state.project() {
|
||||||
|
State::Completed { mut io, remaining } => {
|
||||||
|
while !remaining.is_empty() {
|
||||||
|
let n = ready!(io.as_mut().poll_write(cx, &remaining)?);
|
||||||
|
if n == 0 {
|
||||||
|
return Poll::Ready(Err(io::ErrorKind::WriteZero.into()))
|
||||||
|
}
|
||||||
|
remaining.advance(n);
|
||||||
|
}
|
||||||
|
io.poll_flush(cx)
|
||||||
|
},
|
||||||
|
State::Expecting { io, .. } => io.poll_flush(cx),
|
||||||
|
State::Invalid => panic!("Negotiated: Invalid state"),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[project]
|
||||||
|
fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), io::Error>> {
|
||||||
// Ensure all data has been flushed and expected negotiation messages
|
// Ensure all data has been flushed and expected negotiation messages
|
||||||
// have been received.
|
// have been received.
|
||||||
try_ready!(self.poll().map_err(Into::<io::Error>::into));
|
ready!(self.as_mut().poll(cx).map_err(Into::<io::Error>::into)?);
|
||||||
|
ready!(self.as_mut().poll_flush(cx).map_err(Into::<io::Error>::into)?);
|
||||||
|
|
||||||
// Continue with the shutdown of the underlying I/O stream.
|
// Continue with the shutdown of the underlying I/O stream.
|
||||||
match &mut self.state {
|
#[project]
|
||||||
State::Completed { io, .. } => io.shutdown(),
|
match self.project().state.project() {
|
||||||
State::Expecting { io, .. } => io.shutdown(),
|
State::Completed { io, .. } => io.poll_close(cx),
|
||||||
State::Invalid => panic!("Negotiated: Invalid state")
|
State::Expecting { io, .. } => io.poll_close(cx),
|
||||||
|
State::Invalid => panic!("Negotiated: Invalid state"),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[project]
|
||||||
|
fn poll_write_vectored(self: Pin<&mut Self>, cx: &mut Context, bufs: &[IoSlice])
|
||||||
|
-> Poll<Result<usize, io::Error>>
|
||||||
|
{
|
||||||
|
#[project]
|
||||||
|
match self.project().state.project() {
|
||||||
|
State::Completed { mut io, remaining } => {
|
||||||
|
while !remaining.is_empty() {
|
||||||
|
let n = ready!(io.as_mut().poll_write(cx, &remaining)?);
|
||||||
|
if n == 0 {
|
||||||
|
return Poll::Ready(Err(io::ErrorKind::WriteZero.into()))
|
||||||
|
}
|
||||||
|
remaining.advance(n);
|
||||||
|
}
|
||||||
|
io.poll_write_vectored(cx, bufs)
|
||||||
|
},
|
||||||
|
State::Expecting { io, .. } => io.poll_write_vectored(cx, bufs),
|
||||||
|
State::Invalid => panic!("Negotiated: Invalid state"),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -300,12 +359,12 @@ impl From<io::Error> for NegotiationError {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Into<io::Error> for NegotiationError {
|
impl From<NegotiationError> for io::Error {
|
||||||
fn into(self) -> io::Error {
|
fn from(err: NegotiationError) -> io::Error {
|
||||||
if let NegotiationError::ProtocolError(e) = self {
|
if let NegotiationError::ProtocolError(e) = err {
|
||||||
return e.into()
|
return e.into()
|
||||||
}
|
}
|
||||||
io::Error::new(io::ErrorKind::Other, self)
|
io::Error::new(io::ErrorKind::Other, err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -333,27 +392,33 @@ impl fmt::Display for NegotiationError {
|
|||||||
mod tests {
|
mod tests {
|
||||||
use super::*;
|
use super::*;
|
||||||
use quickcheck::*;
|
use quickcheck::*;
|
||||||
use std::io::Write;
|
use std::{io::Write, task::Poll};
|
||||||
|
|
||||||
/// An I/O resource with a fixed write capacity (total and per write op).
|
/// An I/O resource with a fixed write capacity (total and per write op).
|
||||||
struct Capped { buf: Vec<u8>, step: usize }
|
struct Capped { buf: Vec<u8>, step: usize }
|
||||||
|
|
||||||
impl io::Write for Capped {
|
impl AsyncRead for Capped {
|
||||||
fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
|
fn poll_read(self: Pin<&mut Self>, _: &mut Context, _: &mut [u8]) -> Poll<Result<usize, io::Error>> {
|
||||||
if self.buf.len() + buf.len() > self.buf.capacity() {
|
unreachable!()
|
||||||
return Err(io::ErrorKind::WriteZero.into())
|
|
||||||
}
|
|
||||||
self.buf.write(&buf[.. usize::min(self.step, buf.len())])
|
|
||||||
}
|
|
||||||
|
|
||||||
fn flush(&mut self) -> io::Result<()> {
|
|
||||||
Ok(())
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl AsyncWrite for Capped {
|
impl AsyncWrite for Capped {
|
||||||
fn shutdown(&mut self) -> Poll<(), io::Error> {
|
fn poll_write(mut self: Pin<&mut Self>, _: &mut Context, buf: &[u8]) -> Poll<Result<usize, io::Error>> {
|
||||||
Ok(().into())
|
if self.buf.len() + buf.len() > self.buf.capacity() {
|
||||||
|
return Poll::Ready(Err(io::ErrorKind::WriteZero.into()))
|
||||||
|
}
|
||||||
|
let len = usize::min(self.step, buf.len());
|
||||||
|
let n = Write::write(&mut self.buf, &buf[.. len]).unwrap();
|
||||||
|
Poll::Ready(Ok(n))
|
||||||
|
}
|
||||||
|
|
||||||
|
fn poll_flush(self: Pin<&mut Self>, _: &mut Context) -> Poll<Result<(), io::Error>> {
|
||||||
|
Poll::Ready(Ok(()))
|
||||||
|
}
|
||||||
|
|
||||||
|
fn poll_close(self: Pin<&mut Self>, _: &mut Context) -> Poll<Result<(), io::Error>> {
|
||||||
|
Poll::Ready(Ok(()))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -369,7 +434,7 @@ mod tests {
|
|||||||
loop {
|
loop {
|
||||||
// Write until `new` has been fully written or the capped buffer runs
|
// Write until `new` has been fully written or the capped buffer runs
|
||||||
// over capacity and yields WriteZero.
|
// over capacity and yields WriteZero.
|
||||||
match io.write(&new[written..]) {
|
match future::poll_fn(|cx| Pin::new(&mut io).poll_write(cx, &new[written..])).now_or_never().unwrap() {
|
||||||
Ok(n) =>
|
Ok(n) =>
|
||||||
if let State::Completed { remaining, .. } = &io.state {
|
if let State::Completed { remaining, .. } = &io.state {
|
||||||
assert!(remaining.is_empty());
|
assert!(remaining.is_empty());
|
||||||
@ -388,7 +453,7 @@ mod tests {
|
|||||||
return TestResult::failed()
|
return TestResult::failed()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
Err(e) => panic!("Unexpected error: {:?}", e)
|
Err(e) => panic!("Unexpected error: {:?}", e),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -25,12 +25,11 @@
|
|||||||
//! `Stream` and `Sink` implementations of `MessageIO` and
|
//! `Stream` and `Sink` implementations of `MessageIO` and
|
||||||
//! `MessageReader`.
|
//! `MessageReader`.
|
||||||
|
|
||||||
use bytes::{Bytes, BytesMut, BufMut};
|
|
||||||
use crate::length_delimited::{LengthDelimited, LengthDelimitedReader};
|
use crate::length_delimited::{LengthDelimited, LengthDelimitedReader};
|
||||||
use futures::{prelude::*, try_ready};
|
|
||||||
use log::trace;
|
use bytes::{Bytes, BytesMut, BufMut};
|
||||||
use std::{io, fmt, error::Error, convert::TryFrom};
|
use futures::{prelude::*, io::IoSlice, ready};
|
||||||
use tokio_io::{AsyncRead, AsyncWrite};
|
use std::{convert::TryFrom, io, fmt, error::Error, pin::Pin, task::{Context, Poll}};
|
||||||
use unsigned_varint as uvi;
|
use unsigned_varint as uvi;
|
||||||
|
|
||||||
/// The maximum number of supported protocols that can be processed.
|
/// The maximum number of supported protocols that can be processed.
|
||||||
@ -264,7 +263,9 @@ impl Message {
|
|||||||
}
|
}
|
||||||
|
|
||||||
/// A `MessageIO` implements a [`Stream`] and [`Sink`] of [`Message`]s.
|
/// A `MessageIO` implements a [`Stream`] and [`Sink`] of [`Message`]s.
|
||||||
|
#[pin_project::pin_project]
|
||||||
pub struct MessageIO<R> {
|
pub struct MessageIO<R> {
|
||||||
|
#[pin]
|
||||||
inner: LengthDelimited<R>,
|
inner: LengthDelimited<R>,
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -277,8 +278,8 @@ impl<R> MessageIO<R> {
|
|||||||
Self { inner: LengthDelimited::new(inner) }
|
Self { inner: LengthDelimited::new(inner) }
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Converts the `MessageIO` into a `MessageReader`, dropping the
|
/// Converts the [`MessageIO`] into a [`MessageReader`], dropping the
|
||||||
/// `Message`-oriented `Sink` in favour of direct `AsyncWrite` access
|
/// [`Message`]-oriented `Sink` in favour of direct `AsyncWrite` access
|
||||||
/// to the underlying I/O stream.
|
/// to the underlying I/O stream.
|
||||||
///
|
///
|
||||||
/// This is typically done if further negotiation messages are expected to be
|
/// This is typically done if further negotiation messages are expected to be
|
||||||
@ -288,7 +289,7 @@ impl<R> MessageIO<R> {
|
|||||||
MessageReader { inner: self.inner.into_reader() }
|
MessageReader { inner: self.inner.into_reader() }
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Drops the `MessageIO` resource, yielding the underlying I/O stream
|
/// Drops the [`MessageIO`] resource, yielding the underlying I/O stream
|
||||||
/// together with the remaining write buffer containing the protocol
|
/// together with the remaining write buffer containing the protocol
|
||||||
/// negotiation frame data that has not yet been written to the I/O stream.
|
/// negotiation frame data that has not yet been written to the I/O stream.
|
||||||
///
|
///
|
||||||
@ -309,28 +310,28 @@ impl<R> MessageIO<R> {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<R> Sink for MessageIO<R>
|
impl<R> Sink<Message> for MessageIO<R>
|
||||||
where
|
where
|
||||||
R: AsyncWrite,
|
R: AsyncWrite,
|
||||||
{
|
{
|
||||||
type SinkItem = Message;
|
type Error = ProtocolError;
|
||||||
type SinkError = ProtocolError;
|
|
||||||
|
|
||||||
fn start_send(&mut self, msg: Self::SinkItem) -> StartSend<Self::SinkItem, Self::SinkError> {
|
fn poll_ready(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
|
||||||
|
self.project().inner.poll_ready(cx).map_err(From::from)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn start_send(self: Pin<&mut Self>, item: Message) -> Result<(), Self::Error> {
|
||||||
let mut buf = BytesMut::new();
|
let mut buf = BytesMut::new();
|
||||||
msg.encode(&mut buf)?;
|
item.encode(&mut buf)?;
|
||||||
match self.inner.start_send(buf.freeze())? {
|
self.project().inner.start_send(buf.freeze()).map_err(From::from)
|
||||||
AsyncSink::NotReady(_) => Ok(AsyncSink::NotReady(msg)),
|
|
||||||
AsyncSink::Ready => Ok(AsyncSink::Ready),
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
fn poll_complete(&mut self) -> Poll<(), Self::SinkError> {
|
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
|
||||||
Ok(self.inner.poll_complete()?)
|
self.project().inner.poll_flush(cx).map_err(From::from)
|
||||||
}
|
}
|
||||||
|
|
||||||
fn close(&mut self) -> Poll<(), Self::SinkError> {
|
fn poll_close(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
|
||||||
Ok(self.inner.close()?)
|
self.project().inner.poll_close(cx).map_err(From::from)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -338,18 +339,24 @@ impl<R> Stream for MessageIO<R>
|
|||||||
where
|
where
|
||||||
R: AsyncRead
|
R: AsyncRead
|
||||||
{
|
{
|
||||||
type Item = Message;
|
type Item = Result<Message, ProtocolError>;
|
||||||
type Error = ProtocolError;
|
|
||||||
|
|
||||||
fn poll(&mut self) -> Poll<Option<Self::Item>, Self::Error> {
|
fn poll_next(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Option<Self::Item>> {
|
||||||
poll_stream(&mut self.inner)
|
match poll_stream(self.project().inner, cx) {
|
||||||
|
Poll::Pending => Poll::Pending,
|
||||||
|
Poll::Ready(None) => Poll::Ready(None),
|
||||||
|
Poll::Ready(Some(Ok(m))) => Poll::Ready(Some(Ok(m))),
|
||||||
|
Poll::Ready(Some(Err(err))) => Poll::Ready(Some(Err(From::from(err)))),
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// A `MessageReader` implements a `Stream` of `Message`s on an underlying
|
/// A `MessageReader` implements a `Stream` of `Message`s on an underlying
|
||||||
/// I/O resource combined with direct `AsyncWrite` access.
|
/// I/O resource combined with direct `AsyncWrite` access.
|
||||||
|
#[pin_project::pin_project]
|
||||||
#[derive(Debug)]
|
#[derive(Debug)]
|
||||||
pub struct MessageReader<R> {
|
pub struct MessageReader<R> {
|
||||||
|
#[pin]
|
||||||
inner: LengthDelimitedReader<R>
|
inner: LengthDelimitedReader<R>
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -373,35 +380,16 @@ impl<R> MessageReader<R> {
|
|||||||
pub fn into_inner(self) -> (R, BytesMut) {
|
pub fn into_inner(self) -> (R, BytesMut) {
|
||||||
self.inner.into_inner()
|
self.inner.into_inner()
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Returns a reference to the underlying I/O stream.
|
|
||||||
pub fn inner_ref(&self) -> &R {
|
|
||||||
self.inner.inner_ref()
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<R> Stream for MessageReader<R>
|
impl<R> Stream for MessageReader<R>
|
||||||
where
|
where
|
||||||
R: AsyncRead
|
R: AsyncRead
|
||||||
{
|
{
|
||||||
type Item = Message;
|
type Item = Result<Message, ProtocolError>;
|
||||||
type Error = ProtocolError;
|
|
||||||
|
|
||||||
fn poll(&mut self) -> Poll<Option<Self::Item>, Self::Error> {
|
fn poll_next(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Option<Self::Item>> {
|
||||||
poll_stream(&mut self.inner)
|
poll_stream(self.project().inner, cx)
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl<R> io::Write for MessageReader<R>
|
|
||||||
where
|
|
||||||
R: AsyncWrite
|
|
||||||
{
|
|
||||||
fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
|
|
||||||
self.inner.write(buf)
|
|
||||||
}
|
|
||||||
|
|
||||||
fn flush(&mut self) -> io::Result<()> {
|
|
||||||
self.inner.flush()
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -409,24 +397,39 @@ impl<TInner> AsyncWrite for MessageReader<TInner>
|
|||||||
where
|
where
|
||||||
TInner: AsyncWrite
|
TInner: AsyncWrite
|
||||||
{
|
{
|
||||||
fn shutdown(&mut self) -> Poll<(), io::Error> {
|
fn poll_write(self: Pin<&mut Self>, cx: &mut Context, buf: &[u8]) -> Poll<Result<usize, io::Error>> {
|
||||||
self.inner.shutdown()
|
self.project().inner.poll_write(cx, buf)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), io::Error>> {
|
||||||
|
self.project().inner.poll_flush(cx)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn poll_close(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), io::Error>> {
|
||||||
|
self.project().inner.poll_close(cx)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn poll_write_vectored(self: Pin<&mut Self>, cx: &mut Context, bufs: &[IoSlice]) -> Poll<Result<usize, io::Error>> {
|
||||||
|
self.project().inner.poll_write_vectored(cx, bufs)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn poll_stream<S>(stream: &mut S) -> Poll<Option<Message>, ProtocolError>
|
fn poll_stream<S>(stream: Pin<&mut S>, cx: &mut Context) -> Poll<Option<Result<Message, ProtocolError>>>
|
||||||
where
|
where
|
||||||
S: Stream<Item = Bytes, Error = io::Error>,
|
S: Stream<Item = Result<Bytes, io::Error>>,
|
||||||
{
|
{
|
||||||
let msg = if let Some(msg) = try_ready!(stream.poll()) {
|
let msg = if let Some(msg) = ready!(stream.poll_next(cx)?) {
|
||||||
Message::decode(msg)?
|
match Message::decode(msg) {
|
||||||
|
Ok(m) => m,
|
||||||
|
Err(err) => return Poll::Ready(Some(Err(err))),
|
||||||
|
}
|
||||||
} else {
|
} else {
|
||||||
return Ok(Async::Ready(None))
|
return Poll::Ready(None)
|
||||||
};
|
};
|
||||||
|
|
||||||
trace!("Received message: {:?}", msg);
|
log::trace!("Received message: {:?}", msg);
|
||||||
|
|
||||||
Ok(Async::Ready(Some(msg)))
|
Poll::Ready(Some(Ok(msg)))
|
||||||
}
|
}
|
||||||
|
|
||||||
/// A protocol error.
|
/// A protocol error.
|
||||||
|
@ -25,164 +25,156 @@
|
|||||||
use crate::{Version, NegotiationError};
|
use crate::{Version, NegotiationError};
|
||||||
use crate::dialer_select::{dialer_select_proto_parallel, dialer_select_proto_serial};
|
use crate::dialer_select::{dialer_select_proto_parallel, dialer_select_proto_serial};
|
||||||
use crate::{dialer_select_proto, listener_select_proto};
|
use crate::{dialer_select_proto, listener_select_proto};
|
||||||
|
|
||||||
|
use async_std::net::{TcpListener, TcpStream};
|
||||||
use futures::prelude::*;
|
use futures::prelude::*;
|
||||||
use tokio::runtime::current_thread::Runtime;
|
|
||||||
use tokio_tcp::{TcpListener, TcpStream};
|
|
||||||
use tokio_io::io as nio;
|
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn select_proto_basic() {
|
fn select_proto_basic() {
|
||||||
fn run(version: Version) {
|
async fn run(version: Version) {
|
||||||
let listener = TcpListener::bind(&"127.0.0.1:0".parse().unwrap()).unwrap();
|
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
|
||||||
let listener_addr = listener.local_addr().unwrap();
|
let listener_addr = listener.local_addr().unwrap();
|
||||||
|
|
||||||
let server = listener
|
let server = async_std::task::spawn(async move {
|
||||||
.incoming()
|
let connec = listener.accept().await.unwrap().0;
|
||||||
.into_future()
|
let protos = vec![b"/proto1", b"/proto2"];
|
||||||
.map(|s| s.0.unwrap())
|
let (proto, mut io) = listener_select_proto(connec, protos).await.unwrap();
|
||||||
.map_err(|(e, _)| e.into())
|
assert_eq!(proto, b"/proto2");
|
||||||
.and_then(move |connec| {
|
|
||||||
let protos = vec![b"/proto1", b"/proto2"];
|
|
||||||
listener_select_proto(connec, protos)
|
|
||||||
})
|
|
||||||
.and_then(|(proto, io)| {
|
|
||||||
nio::write_all(io, b"pong").from_err().map(move |_| proto)
|
|
||||||
});
|
|
||||||
|
|
||||||
let client = TcpStream::connect(&listener_addr)
|
let mut out = vec![0; 32];
|
||||||
.from_err()
|
let n = io.read(&mut out).await.unwrap();
|
||||||
.and_then(move |connec| {
|
out.truncate(n);
|
||||||
let protos = vec![b"/proto3", b"/proto2"];
|
assert_eq!(out, b"ping");
|
||||||
dialer_select_proto(connec, protos, version)
|
|
||||||
})
|
|
||||||
.and_then(|(proto, io)| {
|
|
||||||
nio::write_all(io, b"ping").from_err().map(move |(io, _)| (proto, io))
|
|
||||||
})
|
|
||||||
.and_then(|(proto, io)| {
|
|
||||||
nio::read_exact(io, [0; 4]).from_err().map(move |(_, msg)| {
|
|
||||||
assert_eq!(&msg, b"pong");
|
|
||||||
proto
|
|
||||||
})
|
|
||||||
});
|
|
||||||
|
|
||||||
let mut rt = Runtime::new().unwrap();
|
io.write_all(b"pong").await.unwrap();
|
||||||
let (dialer_chosen, listener_chosen) =
|
io.flush().await.unwrap();
|
||||||
rt.block_on(client.join(server)).unwrap();
|
});
|
||||||
|
|
||||||
assert_eq!(dialer_chosen, b"/proto2");
|
let client = async_std::task::spawn(async move {
|
||||||
assert_eq!(listener_chosen, b"/proto2");
|
let connec = TcpStream::connect(&listener_addr).await.unwrap();
|
||||||
|
let protos = vec![b"/proto3", b"/proto2"];
|
||||||
|
let (proto, mut io) = dialer_select_proto(connec, protos.into_iter(), version)
|
||||||
|
.await.unwrap();
|
||||||
|
assert_eq!(proto, b"/proto2");
|
||||||
|
|
||||||
|
io.write_all(b"ping").await.unwrap();
|
||||||
|
io.flush().await.unwrap();
|
||||||
|
|
||||||
|
let mut out = vec![0; 32];
|
||||||
|
let n = io.read(&mut out).await.unwrap();
|
||||||
|
out.truncate(n);
|
||||||
|
assert_eq!(out, b"pong");
|
||||||
|
});
|
||||||
|
|
||||||
|
server.await;
|
||||||
|
client.await;
|
||||||
}
|
}
|
||||||
|
|
||||||
run(Version::V1);
|
async_std::task::block_on(run(Version::V1));
|
||||||
run(Version::V1Lazy);
|
async_std::task::block_on(run(Version::V1Lazy));
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn no_protocol_found() {
|
fn no_protocol_found() {
|
||||||
fn run(version: Version) {
|
async fn run(version: Version) {
|
||||||
let listener = TcpListener::bind(&"127.0.0.1:0".parse().unwrap()).unwrap();
|
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
|
||||||
let listener_addr = listener.local_addr().unwrap();
|
let listener_addr = listener.local_addr().unwrap();
|
||||||
|
|
||||||
let server = listener
|
let server = async_std::task::spawn(async move {
|
||||||
.incoming()
|
let connec = listener.accept().await.unwrap().0;
|
||||||
.into_future()
|
let protos = vec![b"/proto1", b"/proto2"];
|
||||||
.map(|s| s.0.unwrap())
|
let io = match listener_select_proto(connec, protos).await {
|
||||||
.map_err(|(e, _)| e.into())
|
Ok((_, io)) => io,
|
||||||
.and_then(move |connec| {
|
// We don't explicitly check for `Failed` because the client might close the connection when it
|
||||||
let protos = vec![b"/proto1", b"/proto2"];
|
// realizes that we have no protocol in common.
|
||||||
listener_select_proto(connec, protos)
|
Err(_) => return,
|
||||||
})
|
};
|
||||||
.and_then(|(proto, io)| io.complete().map(move |_| proto));
|
match io.complete().await {
|
||||||
|
Err(NegotiationError::Failed) => {},
|
||||||
|
_ => panic!(),
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
let client = TcpStream::connect(&listener_addr)
|
let client = async_std::task::spawn(async move {
|
||||||
.from_err()
|
let connec = TcpStream::connect(&listener_addr).await.unwrap();
|
||||||
.and_then(move |connec| {
|
let protos = vec![b"/proto3", b"/proto4"];
|
||||||
let protos = vec![b"/proto3", b"/proto4"];
|
let io = match dialer_select_proto(connec, protos.into_iter(), version).await {
|
||||||
dialer_select_proto(connec, protos, version)
|
Err(NegotiationError::Failed) => return,
|
||||||
})
|
Ok((_, io)) => io,
|
||||||
.and_then(|(proto, io)| io.complete().map(move |_| proto));
|
Err(_) => panic!()
|
||||||
|
};
|
||||||
|
match io.complete().await {
|
||||||
|
Err(NegotiationError::Failed) => {},
|
||||||
|
_ => panic!(),
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
let mut rt = Runtime::new().unwrap();
|
server.await;
|
||||||
match rt.block_on(client.join(server)) {
|
client.await;
|
||||||
Err(NegotiationError::Failed) => (),
|
|
||||||
e => panic!("{:?}", e),
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
run(Version::V1);
|
async_std::task::block_on(run(Version::V1));
|
||||||
run(Version::V1Lazy);
|
async_std::task::block_on(run(Version::V1Lazy));
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn select_proto_parallel() {
|
fn select_proto_parallel() {
|
||||||
fn run(version: Version) {
|
async fn run(version: Version) {
|
||||||
let listener = TcpListener::bind(&"127.0.0.1:0".parse().unwrap()).unwrap();
|
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
|
||||||
let listener_addr = listener.local_addr().unwrap();
|
let listener_addr = listener.local_addr().unwrap();
|
||||||
|
|
||||||
let server = listener
|
let server = async_std::task::spawn(async move {
|
||||||
.incoming()
|
let connec = listener.accept().await.unwrap().0;
|
||||||
.into_future()
|
let protos = vec![b"/proto1", b"/proto2"];
|
||||||
.map(|s| s.0.unwrap())
|
let (proto, io) = listener_select_proto(connec, protos).await.unwrap();
|
||||||
.map_err(|(e, _)| e.into())
|
assert_eq!(proto, b"/proto2");
|
||||||
.and_then(move |connec| {
|
io.complete().await.unwrap();
|
||||||
let protos = vec![b"/proto1", b"/proto2"];
|
});
|
||||||
listener_select_proto(connec, protos)
|
|
||||||
})
|
|
||||||
.and_then(|(proto, io)| io.complete().map(move |_| proto));
|
|
||||||
|
|
||||||
let client = TcpStream::connect(&listener_addr)
|
let client = async_std::task::spawn(async move {
|
||||||
.from_err()
|
let connec = TcpStream::connect(&listener_addr).await.unwrap();
|
||||||
.and_then(move |connec| {
|
let protos = vec![b"/proto3", b"/proto2"];
|
||||||
let protos = vec![b"/proto3", b"/proto2"];
|
let (proto, io) = dialer_select_proto_parallel(connec, protos.into_iter(), version)
|
||||||
dialer_select_proto_parallel(connec, protos.into_iter(), version)
|
.await.unwrap();
|
||||||
})
|
assert_eq!(proto, b"/proto2");
|
||||||
.and_then(|(proto, io)| io.complete().map(move |_| proto));
|
io.complete().await.unwrap();
|
||||||
|
});
|
||||||
|
|
||||||
let mut rt = Runtime::new().unwrap();
|
server.await;
|
||||||
let (dialer_chosen, listener_chosen) =
|
client.await;
|
||||||
rt.block_on(client.join(server)).unwrap();
|
|
||||||
|
|
||||||
assert_eq!(dialer_chosen, b"/proto2");
|
|
||||||
assert_eq!(listener_chosen, b"/proto2");
|
|
||||||
}
|
}
|
||||||
|
|
||||||
run(Version::V1);
|
async_std::task::block_on(run(Version::V1));
|
||||||
run(Version::V1Lazy);
|
async_std::task::block_on(run(Version::V1Lazy));
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn select_proto_serial() {
|
fn select_proto_serial() {
|
||||||
fn run(version: Version) {
|
async fn run(version: Version) {
|
||||||
let listener = TcpListener::bind(&"127.0.0.1:0".parse().unwrap()).unwrap();
|
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
|
||||||
let listener_addr = listener.local_addr().unwrap();
|
let listener_addr = listener.local_addr().unwrap();
|
||||||
|
|
||||||
let server = listener
|
let server = async_std::task::spawn(async move {
|
||||||
.incoming()
|
let connec = listener.accept().await.unwrap().0;
|
||||||
.into_future()
|
let protos = vec![b"/proto1", b"/proto2"];
|
||||||
.map(|s| s.0.unwrap())
|
let (proto, io) = listener_select_proto(connec, protos).await.unwrap();
|
||||||
.map_err(|(e, _)| e.into())
|
assert_eq!(proto, b"/proto2");
|
||||||
.and_then(move |connec| {
|
io.complete().await.unwrap();
|
||||||
let protos = vec![b"/proto1", b"/proto2"];
|
});
|
||||||
listener_select_proto(connec, protos)
|
|
||||||
})
|
|
||||||
.and_then(|(proto, io)| io.complete().map(move |_| proto));
|
|
||||||
|
|
||||||
let client = TcpStream::connect(&listener_addr)
|
let client = async_std::task::spawn(async move {
|
||||||
.from_err()
|
let connec = TcpStream::connect(&listener_addr).await.unwrap();
|
||||||
.and_then(move |connec| {
|
let protos = vec![b"/proto3", b"/proto2"];
|
||||||
let protos = vec![b"/proto3", b"/proto2"];
|
let (proto, io) = dialer_select_proto_serial(connec, protos.into_iter(), version)
|
||||||
dialer_select_proto_serial(connec, protos.into_iter(), version)
|
.await.unwrap();
|
||||||
})
|
assert_eq!(proto, b"/proto2");
|
||||||
.and_then(|(proto, io)| io.complete().map(move |_| proto));
|
io.complete().await.unwrap();
|
||||||
|
});
|
||||||
|
|
||||||
let mut rt = Runtime::new().unwrap();
|
server.await;
|
||||||
let (dialer_chosen, listener_chosen) =
|
client.await;
|
||||||
rt.block_on(client.join(server)).unwrap();
|
|
||||||
|
|
||||||
assert_eq!(dialer_chosen, b"/proto2");
|
|
||||||
assert_eq!(listener_chosen, b"/proto2");
|
|
||||||
}
|
}
|
||||||
|
|
||||||
run(Version::V1);
|
async_std::task::block_on(run(Version::V1));
|
||||||
run(Version::V1Lazy);
|
async_std::task::block_on(run(Version::V1Lazy));
|
||||||
}
|
}
|
||||||
|
Reference in New Issue
Block a user