*: Format with rustfmt (#2188)

Co-authored-by: Thomas Eizinger <thomas@eizinger.io>
This commit is contained in:
Max Inden
2021-08-11 13:12:12 +02:00
committed by GitHub
parent 008561283e
commit f701b24ec0
171 changed files with 10051 additions and 7193 deletions

View File

@ -20,11 +20,16 @@
//! Protocol negotiation strategies for the peer acting as the dialer.
use crate::protocol::{HeaderLine, Message, MessageIO, Protocol, ProtocolError};
use crate::{Negotiated, NegotiationError, Version};
use crate::protocol::{Protocol, ProtocolError, MessageIO, Message, HeaderLine};
use futures::{future::Either, prelude::*};
use std::{convert::TryFrom as _, iter, mem, pin::Pin, task::{Context, Poll}};
use std::{
convert::TryFrom as _,
iter, mem,
pin::Pin,
task::{Context, Poll},
};
/// Returns a `Future` that negotiates a protocol on the given I/O stream
/// for a peer acting as the _dialer_ (or _initiator_).
@ -48,17 +53,17 @@ use std::{convert::TryFrom as _, iter, mem, pin::Pin, task::{Context, Poll}};
pub fn dialer_select_proto<R, I>(
inner: R,
protocols: I,
version: Version
version: Version,
) -> DialerSelectFuture<R, I::IntoIter>
where
R: AsyncRead + AsyncWrite,
I: IntoIterator,
I::Item: AsRef<[u8]>
I::Item: AsRef<[u8]>,
{
let iter = protocols.into_iter();
// We choose between the "serial" and "parallel" strategies based on the number of protocols.
if iter.size_hint().1.map(|n| n <= 3).unwrap_or(false) {
Either::Left(dialer_select_proto_serial(inner, iter, version))
Either::Left(dialer_select_proto_serial(inner, iter, version))
} else {
Either::Right(dialer_select_proto_parallel(inner, iter, version))
}
@ -79,12 +84,12 @@ pub type DialerSelectFuture<R, I> = Either<DialerSelectSeq<R, I>, DialerSelectPa
pub(crate) fn dialer_select_proto_serial<R, I>(
inner: R,
protocols: I,
version: Version
version: Version,
) -> DialerSelectSeq<R, I::IntoIter>
where
R: AsyncRead + AsyncWrite,
I: IntoIterator,
I::Item: AsRef<[u8]>
I::Item: AsRef<[u8]>,
{
let protocols = protocols.into_iter().peekable();
DialerSelectSeq {
@ -92,7 +97,7 @@ where
protocols,
state: SeqState::SendHeader {
io: MessageIO::new(inner),
}
},
}
}
@ -108,20 +113,20 @@ where
pub(crate) fn dialer_select_proto_parallel<R, I>(
inner: R,
protocols: I,
version: Version
version: Version,
) -> DialerSelectPar<R, I::IntoIter>
where
R: AsyncRead + AsyncWrite,
I: IntoIterator,
I::Item: AsRef<[u8]>
I::Item: AsRef<[u8]>,
{
let protocols = protocols.into_iter();
DialerSelectPar {
version,
protocols,
state: ParState::SendHeader {
io: MessageIO::new(inner)
}
io: MessageIO::new(inner),
},
}
}
@ -136,11 +141,11 @@ pub struct DialerSelectSeq<R, I: Iterator> {
}
enum SeqState<R, N> {
SendHeader { io: MessageIO<R>, },
SendHeader { io: MessageIO<R> },
SendProtocol { io: MessageIO<R>, protocol: N },
FlushProtocol { io: MessageIO<R>, protocol: N },
AwaitProtocol { io: MessageIO<R>, protocol: N },
Done
Done,
}
impl<R, I> Future for DialerSelectSeq<R, I>
@ -149,7 +154,7 @@ where
// It also makes the implementation considerably easier to write.
R: AsyncRead + AsyncWrite + Unpin,
I: Iterator,
I::Item: AsRef<[u8]>
I::Item: AsRef<[u8]>,
{
type Output = Result<(I::Item, Negotiated<R>), NegotiationError>;
@ -160,11 +165,11 @@ where
match mem::replace(this.state, SeqState::Done) {
SeqState::SendHeader { mut io } => {
match Pin::new(&mut io).poll_ready(cx)? {
Poll::Ready(()) => {},
Poll::Ready(()) => {}
Poll::Pending => {
*this.state = SeqState::SendHeader { io };
return Poll::Pending
},
return Poll::Pending;
}
}
let h = HeaderLine::from(*this.version);
@ -181,11 +186,11 @@ where
SeqState::SendProtocol { mut io, protocol } => {
match Pin::new(&mut io).poll_ready(cx)? {
Poll::Ready(()) => {},
Poll::Ready(()) => {}
Poll::Pending => {
*this.state = SeqState::SendProtocol { io, protocol };
return Poll::Pending
},
return Poll::Pending;
}
}
let p = Protocol::try_from(protocol.as_ref())?;
@ -207,7 +212,7 @@ where
log::debug!("Dialer: Expecting proposed protocol: {}", p);
let hl = HeaderLine::from(Version::V1Lazy);
let io = Negotiated::expecting(io.into_reader(), p, Some(hl));
return Poll::Ready(Ok((protocol, io)))
return Poll::Ready(Ok((protocol, io)));
}
}
}
@ -218,8 +223,8 @@ where
Poll::Ready(()) => *this.state = SeqState::AwaitProtocol { io, protocol },
Poll::Pending => {
*this.state = SeqState::FlushProtocol { io, protocol };
return Poll::Pending
},
return Poll::Pending;
}
}
}
@ -228,7 +233,7 @@ where
Poll::Ready(Some(msg)) => msg,
Poll::Pending => {
*this.state = SeqState::AwaitProtocol { io, protocol };
return Poll::Pending
return Poll::Pending;
}
// Treat EOF error as [`NegotiationError::Failed`], not as
// [`NegotiationError::ProtocolError`], allowing dropping or closing an I/O
@ -246,16 +251,18 @@ where
return Poll::Ready(Ok((protocol, io)));
}
Message::NotAvailable => {
log::debug!("Dialer: Received rejection of protocol: {}",
String::from_utf8_lossy(protocol.as_ref()));
log::debug!(
"Dialer: Received rejection of protocol: {}",
String::from_utf8_lossy(protocol.as_ref())
);
let protocol = this.protocols.next().ok_or(NegotiationError::Failed)?;
*this.state = SeqState::SendProtocol { io, protocol }
}
_ => return Poll::Ready(Err(ProtocolError::InvalidMessage.into()))
_ => return Poll::Ready(Err(ProtocolError::InvalidMessage.into())),
}
}
SeqState::Done => panic!("SeqState::poll called after completion")
SeqState::Done => panic!("SeqState::poll called after completion"),
}
}
}
@ -277,7 +284,7 @@ enum ParState<R, N> {
Flush { io: MessageIO<R> },
RecvProtocols { io: MessageIO<R> },
SendProtocol { io: MessageIO<R>, protocol: N },
Done
Done,
}
impl<R, I> Future for DialerSelectPar<R, I>
@ -286,7 +293,7 @@ where
// It also makes the implementation considerably easier to write.
R: AsyncRead + AsyncWrite + Unpin,
I: Iterator,
I::Item: AsRef<[u8]>
I::Item: AsRef<[u8]>,
{
type Output = Result<(I::Item, Negotiated<R>), NegotiationError>;
@ -297,11 +304,11 @@ where
match mem::replace(this.state, ParState::Done) {
ParState::SendHeader { mut io } => {
match Pin::new(&mut io).poll_ready(cx)? {
Poll::Ready(()) => {},
Poll::Ready(()) => {}
Poll::Pending => {
*this.state = ParState::SendHeader { io };
return Poll::Pending
},
return Poll::Pending;
}
}
let msg = Message::Header(HeaderLine::from(*this.version));
@ -314,11 +321,11 @@ where
ParState::SendProtocolsRequest { mut io } => {
match Pin::new(&mut io).poll_ready(cx)? {
Poll::Ready(()) => {},
Poll::Ready(()) => {}
Poll::Pending => {
*this.state = ParState::SendProtocolsRequest { io };
return Poll::Pending
},
return Poll::Pending;
}
}
if let Err(err) = Pin::new(&mut io).start_send(Message::ListProtocols) {
@ -329,22 +336,20 @@ where
*this.state = ParState::Flush { io }
}
ParState::Flush { mut io } => {
match Pin::new(&mut io).poll_flush(cx)? {
Poll::Ready(()) => *this.state = ParState::RecvProtocols { io },
Poll::Pending => {
*this.state = ParState::Flush { io };
return Poll::Pending
},
ParState::Flush { mut io } => match Pin::new(&mut io).poll_flush(cx)? {
Poll::Ready(()) => *this.state = ParState::RecvProtocols { io },
Poll::Pending => {
*this.state = ParState::Flush { io };
return Poll::Pending;
}
}
},
ParState::RecvProtocols { mut io } => {
let msg = match Pin::new(&mut io).poll_next(cx)? {
Poll::Ready(Some(msg)) => msg,
Poll::Pending => {
*this.state = ParState::RecvProtocols { io };
return Poll::Pending
return Poll::Pending;
}
// Treat EOF error as [`NegotiationError::Failed`], not as
// [`NegotiationError::ProtocolError`], allowing dropping or closing an I/O
@ -357,12 +362,15 @@ where
*this.state = ParState::RecvProtocols { io }
}
Message::Protocols(supported) => {
let protocol = this.protocols.by_ref()
.find(|p| supported.iter().any(|s|
s.as_ref() == p.as_ref()))
let protocol = this
.protocols
.by_ref()
.find(|p| supported.iter().any(|s| s.as_ref() == p.as_ref()))
.ok_or(NegotiationError::Failed)?;
log::debug!("Dialer: Found supported protocol: {}",
String::from_utf8_lossy(protocol.as_ref()));
log::debug!(
"Dialer: Found supported protocol: {}",
String::from_utf8_lossy(protocol.as_ref())
);
*this.state = ParState::SendProtocol { io, protocol };
}
_ => return Poll::Ready(Err(ProtocolError::InvalidMessage.into())),
@ -371,11 +379,11 @@ where
ParState::SendProtocol { mut io, protocol } => {
match Pin::new(&mut io).poll_ready(cx)? {
Poll::Ready(()) => {},
Poll::Ready(()) => {}
Poll::Pending => {
*this.state = ParState::SendProtocol { io, protocol };
return Poll::Pending
},
return Poll::Pending;
}
}
let p = Protocol::try_from(protocol.as_ref())?;
@ -386,10 +394,10 @@ where
log::debug!("Dialer: Expecting proposed protocol: {}", p);
let io = Negotiated::expecting(io.into_reader(), p, None);
return Poll::Ready(Ok((protocol, io)))
return Poll::Ready(Ok((protocol, io)));
}
ParState::Done => panic!("ParState::poll called after completion")
ParState::Done => panic!("ParState::poll called after completion"),
}
}
}

View File

@ -18,9 +18,15 @@
// FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
// DEALINGS IN THE SOFTWARE.
use bytes::{Bytes, BytesMut, Buf as _, BufMut as _};
use futures::{prelude::*, io::IoSlice};
use std::{convert::TryFrom as _, io, pin::Pin, task::{Poll, Context}, u16};
use bytes::{Buf as _, BufMut as _, Bytes, BytesMut};
use futures::{io::IoSlice, prelude::*};
use std::{
convert::TryFrom as _,
io,
pin::Pin,
task::{Context, Poll},
u16,
};
const MAX_LEN_BYTES: u16 = 2;
const MAX_FRAME_SIZE: u16 = (1 << (MAX_LEN_BYTES * 8 - MAX_LEN_BYTES)) - 1;
@ -50,7 +56,10 @@ pub struct LengthDelimited<R> {
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
enum ReadState {
/// We are currently reading the length of the next frame of data.
ReadLength { buf: [u8; MAX_LEN_BYTES as usize], pos: usize },
ReadLength {
buf: [u8; MAX_LEN_BYTES as usize],
pos: usize,
},
/// We are currently reading the frame of data itself.
ReadData { len: u16, pos: usize },
}
@ -59,7 +68,7 @@ impl Default for ReadState {
fn default() -> Self {
ReadState::ReadLength {
buf: [0; MAX_LEN_BYTES as usize],
pos: 0
pos: 0,
}
}
}
@ -106,10 +115,12 @@ impl<R> LengthDelimited<R> {
///
/// After this method returns `Poll::Ready`, the write buffer of frames
/// submitted to the `Sink` is guaranteed to be empty.
pub fn poll_write_buffer(self: Pin<&mut Self>, cx: &mut Context<'_>)
-> Poll<Result<(), io::Error>>
pub fn poll_write_buffer(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<Result<(), io::Error>>
where
R: AsyncWrite
R: AsyncWrite,
{
let mut this = self.project();
@ -119,7 +130,8 @@ impl<R> LengthDelimited<R> {
Poll::Ready(Ok(0)) => {
return Poll::Ready(Err(io::Error::new(
io::ErrorKind::WriteZero,
"Failed to write buffered frame.")))
"Failed to write buffered frame.",
)))
}
Poll::Ready(Ok(n)) => this.write_buffer.advance(n),
Poll::Ready(Err(err)) => return Poll::Ready(Err(err)),
@ -132,7 +144,7 @@ impl<R> LengthDelimited<R> {
impl<R> Stream for LengthDelimited<R>
where
R: AsyncRead
R: AsyncRead,
{
type Item = Result<Bytes, io::Error>;
@ -142,7 +154,7 @@ where
loop {
match this.read_state {
ReadState::ReadLength { buf, pos } => {
match this.inner.as_mut().poll_read(cx, &mut buf[*pos .. *pos + 1]) {
match this.inner.as_mut().poll_read(cx, &mut buf[*pos..*pos + 1]) {
Poll::Ready(Ok(0)) => {
if *pos == 0 {
return Poll::Ready(None);
@ -160,11 +172,10 @@ where
if (buf[*pos - 1] & 0x80) == 0 {
// MSB is not set, indicating the end of the length prefix.
let (len, _) = unsigned_varint::decode::u16(buf)
.map_err(|e| {
log::debug!("invalid length prefix: {}", e);
io::Error::new(io::ErrorKind::InvalidData, "invalid length prefix")
})?;
let (len, _) = unsigned_varint::decode::u16(buf).map_err(|e| {
log::debug!("invalid length prefix: {}", e);
io::Error::new(io::ErrorKind::InvalidData, "invalid length prefix")
})?;
if len >= 1 {
*this.read_state = ReadState::ReadData { len, pos: 0 };
@ -179,12 +190,19 @@ where
// See the module documentation about the max frame len.
return Poll::Ready(Some(Err(io::Error::new(
io::ErrorKind::InvalidData,
"Maximum frame length exceeded"))));
"Maximum frame length exceeded",
))));
}
}
ReadState::ReadData { len, pos } => {
match this.inner.as_mut().poll_read(cx, &mut this.read_buffer[*pos..]) {
Poll::Ready(Ok(0)) => return Poll::Ready(Some(Err(io::ErrorKind::UnexpectedEof.into()))),
match this
.inner
.as_mut()
.poll_read(cx, &mut this.read_buffer[*pos..])
{
Poll::Ready(Ok(0)) => {
return Poll::Ready(Some(Err(io::ErrorKind::UnexpectedEof.into())))
}
Poll::Ready(Ok(n)) => *pos += n,
Poll::Pending => return Poll::Pending,
Poll::Ready(Err(err)) => return Poll::Ready(Some(Err(err))),
@ -214,7 +232,7 @@ where
// implied to be roughly 2 * MAX_FRAME_SIZE.
if self.as_mut().project().write_buffer.len() >= MAX_FRAME_SIZE as usize {
match self.as_mut().poll_write_buffer(cx) {
Poll::Ready(Ok(())) => {},
Poll::Ready(Ok(())) => {}
Poll::Ready(Err(err)) => return Poll::Ready(Err(err)),
Poll::Pending => return Poll::Pending,
}
@ -233,7 +251,8 @@ where
_ => {
return Err(io::Error::new(
io::ErrorKind::InvalidData,
"Maximum frame size exceeded."))
"Maximum frame size exceeded.",
))
}
};
@ -249,7 +268,7 @@ where
fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
// Write all buffered frame data to the underlying I/O stream.
match LengthDelimited::poll_write_buffer(self.as_mut(), cx) {
Poll::Ready(Ok(())) => {},
Poll::Ready(Ok(())) => {}
Poll::Ready(Err(err)) => return Poll::Ready(Err(err)),
Poll::Pending => return Poll::Pending,
}
@ -264,7 +283,7 @@ where
fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
// Write all buffered frame data to the underlying I/O stream.
match LengthDelimited::poll_write_buffer(self.as_mut(), cx) {
Poll::Ready(Ok(())) => {},
Poll::Ready(Ok(())) => {}
Poll::Ready(Err(err)) => return Poll::Ready(Err(err)),
Poll::Pending => return Poll::Pending,
}
@ -283,7 +302,7 @@ where
#[derive(Debug)]
pub struct LengthDelimitedReader<R> {
#[pin]
inner: LengthDelimited<R>
inner: LengthDelimited<R>,
}
impl<R> LengthDelimitedReader<R> {
@ -306,7 +325,7 @@ impl<R> LengthDelimitedReader<R> {
impl<R> Stream for LengthDelimitedReader<R>
where
R: AsyncRead
R: AsyncRead,
{
type Item = Result<Bytes, io::Error>;
@ -317,17 +336,19 @@ where
impl<R> AsyncWrite for LengthDelimitedReader<R>
where
R: AsyncWrite
R: AsyncWrite,
{
fn poll_write(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8])
-> Poll<Result<usize, io::Error>>
{
fn poll_write(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<Result<usize, io::Error>> {
// `this` here designates the `LengthDelimited`.
let mut this = self.project().inner;
// We need to flush any data previously written with the `LengthDelimited`.
match LengthDelimited::poll_write_buffer(this.as_mut(), cx) {
Poll::Ready(Ok(())) => {},
Poll::Ready(Ok(())) => {}
Poll::Ready(Err(err)) => return Poll::Ready(Err(err)),
Poll::Pending => return Poll::Pending,
}
@ -344,15 +365,17 @@ where
self.project().inner.poll_close(cx)
}
fn poll_write_vectored(self: Pin<&mut Self>, cx: &mut Context<'_>, bufs: &[IoSlice<'_>])
-> Poll<Result<usize, io::Error>>
{
fn poll_write_vectored(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
bufs: &[IoSlice<'_>],
) -> Poll<Result<usize, io::Error>> {
// `this` here designates the `LengthDelimited`.
let mut this = self.project().inner;
// We need to flush any data previously written with the `LengthDelimited`.
match LengthDelimited::poll_write_buffer(this.as_mut(), cx) {
Poll::Ready(Ok(())) => {},
Poll::Ready(Ok(())) => {}
Poll::Ready(Err(err)) => return Poll::Ready(Err(err)),
Poll::Pending => return Poll::Pending,
}
@ -366,7 +389,7 @@ where
mod tests {
use crate::length_delimited::LengthDelimited;
use async_std::net::{TcpListener, TcpStream};
use futures::{prelude::*, io::Cursor};
use futures::{io::Cursor, prelude::*};
use quickcheck::*;
use std::io::ErrorKind;
@ -394,9 +417,7 @@ mod tests {
let mut data = vec![(len & 0x7f) as u8 | 0x80, (len >> 7) as u8];
data.extend(frame.clone().into_iter());
let mut framed = LengthDelimited::new(Cursor::new(data));
let recved = futures::executor::block_on(async move {
framed.next().await
}).unwrap();
let recved = futures::executor::block_on(async move { framed.next().await }).unwrap();
assert_eq!(recved.unwrap(), frame);
}
@ -405,9 +426,7 @@ mod tests {
let mut data = vec![0x81, 0x81, 0x1];
data.extend((0..16513).map(|_| 0));
let mut framed = LengthDelimited::new(Cursor::new(data));
let recved = futures::executor::block_on(async move {
framed.next().await.unwrap()
});
let recved = futures::executor::block_on(async move { framed.next().await.unwrap() });
if let Err(io_err) = recved {
assert_eq!(io_err.kind(), ErrorKind::InvalidData)
@ -479,7 +498,8 @@ mod tests {
let expected_frames = frames.clone();
let server = async_std::task::spawn(async move {
let socket = listener.accept().await.unwrap().0;
let mut connec = rw_stream_sink::RwStreamSink::new(LengthDelimited::new(socket));
let mut connec =
rw_stream_sink::RwStreamSink::new(LengthDelimited::new(socket));
let mut buf = vec![0u8; 0];
for expected in expected_frames {

View File

@ -94,10 +94,10 @@ mod negotiated;
mod protocol;
mod tests;
pub use self::negotiated::{Negotiated, NegotiatedComplete, NegotiationError};
pub use self::protocol::ProtocolError;
pub use self::dialer_select::{dialer_select_proto, DialerSelectFuture};
pub use self::listener_select::{listener_select_proto, ListenerSelectFuture};
pub use self::negotiated::{Negotiated, NegotiatedComplete, NegotiationError};
pub use self::protocol::ProtocolError;
/// Supported multistream-select versions.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
@ -145,4 +145,4 @@ impl Default for Version {
fn default() -> Self {
Version::V1
}
}
}

View File

@ -21,12 +21,18 @@
//! Protocol negotiation strategies for the peer acting as the listener
//! in a multistream-select protocol negotiation.
use crate::protocol::{HeaderLine, Message, MessageIO, Protocol, ProtocolError};
use crate::{Negotiated, NegotiationError};
use crate::protocol::{Protocol, ProtocolError, MessageIO, Message, HeaderLine};
use futures::prelude::*;
use smallvec::SmallVec;
use std::{convert::TryFrom as _, iter::FromIterator, mem, pin::Pin, task::{Context, Poll}};
use std::{
convert::TryFrom as _,
iter::FromIterator,
mem,
pin::Pin,
task::{Context, Poll},
};
/// Returns a `Future` that negotiates a protocol on the given I/O stream
/// for a peer acting as the _listener_ (or _responder_).
@ -35,28 +41,29 @@ use std::{convert::TryFrom as _, iter::FromIterator, mem, pin::Pin, task::{Conte
/// computation that performs the protocol negotiation with the remote. The
/// returned `Future` resolves with the name of the negotiated protocol and
/// a [`Negotiated`] I/O stream.
pub fn listener_select_proto<R, I>(
inner: R,
protocols: I,
) -> ListenerSelectFuture<R, I::Item>
pub fn listener_select_proto<R, I>(inner: R, protocols: I) -> ListenerSelectFuture<R, I::Item>
where
R: AsyncRead + AsyncWrite,
I: IntoIterator,
I::Item: AsRef<[u8]>
I::Item: AsRef<[u8]>,
{
let protocols = protocols.into_iter().filter_map(|n|
match Protocol::try_from(n.as_ref()) {
let protocols = protocols
.into_iter()
.filter_map(|n| match Protocol::try_from(n.as_ref()) {
Ok(p) => Some((n, p)),
Err(e) => {
log::warn!("Listener: Ignoring invalid protocol: {} due to {}",
String::from_utf8_lossy(n.as_ref()), e);
log::warn!(
"Listener: Ignoring invalid protocol: {} due to {}",
String::from_utf8_lossy(n.as_ref()),
e
);
None
}
});
ListenerSelectFuture {
protocols: SmallVec::from_iter(protocols),
state: State::RecvHeader {
io: MessageIO::new(inner)
io: MessageIO::new(inner),
},
last_sent_na: false,
}
@ -80,19 +87,25 @@ pub struct ListenerSelectFuture<R, N> {
}
enum State<R, N> {
RecvHeader { io: MessageIO<R> },
SendHeader { io: MessageIO<R> },
RecvMessage { io: MessageIO<R> },
RecvHeader {
io: MessageIO<R>,
},
SendHeader {
io: MessageIO<R>,
},
RecvMessage {
io: MessageIO<R>,
},
SendMessage {
io: MessageIO<R>,
message: Message,
protocol: Option<N>
protocol: Option<N>,
},
Flush {
io: MessageIO<R>,
protocol: Option<N>
protocol: Option<N>,
},
Done
Done,
}
impl<R, N> Future for ListenerSelectFuture<R, N>
@ -100,7 +113,7 @@ where
// The Unpin bound here is required because we produce a `Negotiated<R>` as the output.
// It also makes the implementation considerably easier to write.
R: AsyncRead + AsyncWrite + Unpin,
N: AsRef<[u8]> + Clone
N: AsRef<[u8]> + Clone,
{
type Output = Result<(N, Negotiated<R>), NegotiationError>;
@ -111,14 +124,12 @@ where
match mem::replace(this.state, State::Done) {
State::RecvHeader { mut io } => {
match io.poll_next_unpin(cx) {
Poll::Ready(Some(Ok(Message::Header(h)))) => {
match h {
HeaderLine::V1 => *this.state = State::SendHeader { io }
}
}
Poll::Ready(Some(Ok(Message::Header(h)))) => match h {
HeaderLine::V1 => *this.state = State::SendHeader { io },
},
Poll::Ready(Some(Ok(_))) => {
return Poll::Ready(Err(ProtocolError::InvalidMessage.into()))
},
}
Poll::Ready(Some(Err(err))) => return Poll::Ready(Err(From::from(err))),
// Treat EOF error as [`NegotiationError::Failed`], not as
// [`NegotiationError::ProtocolError`], allowing dropping or closing an I/O
@ -126,7 +137,7 @@ where
Poll::Ready(None) => return Poll::Ready(Err(NegotiationError::Failed)),
Poll::Pending => {
*this.state = State::RecvHeader { io };
return Poll::Pending
return Poll::Pending;
}
}
}
@ -135,9 +146,9 @@ where
match Pin::new(&mut io).poll_ready(cx) {
Poll::Pending => {
*this.state = State::SendHeader { io };
return Poll::Pending
},
Poll::Ready(Ok(())) => {},
return Poll::Pending;
}
Poll::Ready(Ok(())) => {}
Poll::Ready(Err(err)) => return Poll::Ready(Err(From::from(err))),
}
@ -175,28 +186,37 @@ where
// the dialer also raises `NegotiationError::Failed` when finally
// reading the `N/A` response.
if let ProtocolError::InvalidMessage = &err {
log::trace!("Listener: Negotiation failed with invalid \
message after protocol rejection.");
return Poll::Ready(Err(NegotiationError::Failed))
log::trace!(
"Listener: Negotiation failed with invalid \
message after protocol rejection."
);
return Poll::Ready(Err(NegotiationError::Failed));
}
if let ProtocolError::IoError(e) = &err {
if e.kind() == std::io::ErrorKind::UnexpectedEof {
log::trace!("Listener: Negotiation failed with EOF \
after protocol rejection.");
return Poll::Ready(Err(NegotiationError::Failed))
log::trace!(
"Listener: Negotiation failed with EOF \
after protocol rejection."
);
return Poll::Ready(Err(NegotiationError::Failed));
}
}
}
return Poll::Ready(Err(From::from(err)))
return Poll::Ready(Err(From::from(err)));
}
};
match msg {
Message::ListProtocols => {
let supported = this.protocols.iter().map(|(_,p)| p).cloned().collect();
let supported =
this.protocols.iter().map(|(_, p)| p).cloned().collect();
let message = Message::Protocols(supported);
*this.state = State::SendMessage { io, message, protocol: None }
*this.state = State::SendMessage {
io,
message,
protocol: None,
}
}
Message::Protocol(p) => {
let protocol = this.protocols.iter().find_map(|(name, proto)| {
@ -211,28 +231,42 @@ where
log::debug!("Listener: confirming protocol: {}", p);
Message::Protocol(p.clone())
} else {
log::debug!("Listener: rejecting protocol: {}",
String::from_utf8_lossy(p.as_ref()));
log::debug!(
"Listener: rejecting protocol: {}",
String::from_utf8_lossy(p.as_ref())
);
Message::NotAvailable
};
*this.state = State::SendMessage { io, message, protocol };
*this.state = State::SendMessage {
io,
message,
protocol,
};
}
_ => return Poll::Ready(Err(ProtocolError::InvalidMessage.into()))
_ => return Poll::Ready(Err(ProtocolError::InvalidMessage.into())),
}
}
State::SendMessage { mut io, message, protocol } => {
State::SendMessage {
mut io,
message,
protocol,
} => {
match Pin::new(&mut io).poll_ready(cx) {
Poll::Pending => {
*this.state = State::SendMessage { io, message, protocol };
return Poll::Pending
},
Poll::Ready(Ok(())) => {},
*this.state = State::SendMessage {
io,
message,
protocol,
};
return Poll::Pending;
}
Poll::Ready(Ok(())) => {}
Poll::Ready(Err(err)) => return Poll::Ready(Err(From::from(err))),
}
if let Message::NotAvailable = &message {
if let Message::NotAvailable = &message {
*this.last_sent_na = true;
} else {
*this.last_sent_na = false;
@ -249,26 +283,28 @@ where
match Pin::new(&mut io).poll_flush(cx) {
Poll::Pending => {
*this.state = State::Flush { io, protocol };
return Poll::Pending
},
return Poll::Pending;
}
Poll::Ready(Ok(())) => {
// If a protocol has been selected, finish negotiation.
// Otherwise expect to receive another message.
match protocol {
Some(protocol) => {
log::debug!("Listener: sent confirmed protocol: {}",
String::from_utf8_lossy(protocol.as_ref()));
log::debug!(
"Listener: sent confirmed protocol: {}",
String::from_utf8_lossy(protocol.as_ref())
);
let io = Negotiated::completed(io.into_inner());
return Poll::Ready(Ok((protocol, io)))
return Poll::Ready(Ok((protocol, io)));
}
None => *this.state = State::RecvMessage { io }
None => *this.state = State::RecvMessage { io },
}
}
Poll::Ready(Err(err)) => return Poll::Ready(Err(From::from(err))),
}
}
State::Done => panic!("State::poll called after completion")
State::Done => panic!("State::poll called after completion"),
}
}
}

View File

@ -18,11 +18,20 @@
// FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
// DEALINGS IN THE SOFTWARE.
use crate::protocol::{Protocol, MessageReader, Message, ProtocolError, HeaderLine};
use crate::protocol::{HeaderLine, Message, MessageReader, Protocol, ProtocolError};
use futures::{prelude::*, io::{IoSlice, IoSliceMut}, ready};
use futures::{
io::{IoSlice, IoSliceMut},
prelude::*,
ready,
};
use pin_project::pin_project;
use std::{error::Error, fmt, io, mem, pin::Pin, task::{Context, Poll}};
use std::{
error::Error,
fmt, io, mem,
pin::Pin,
task::{Context, Poll},
};
/// An I/O stream that has settled on an (application-layer) protocol to use.
///
@ -39,7 +48,7 @@ use std::{error::Error, fmt, io, mem, pin::Pin, task::{Context, Poll}};
#[derive(Debug)]
pub struct Negotiated<TInner> {
#[pin]
state: State<TInner>
state: State<TInner>,
}
/// A `Future` that waits on the completion of protocol negotiation.
@ -57,12 +66,15 @@ where
type Output = Result<Negotiated<TInner>, NegotiationError>;
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let mut io = self.inner.take().expect("NegotiatedFuture called after completion.");
let mut io = self
.inner
.take()
.expect("NegotiatedFuture called after completion.");
match Negotiated::poll(Pin::new(&mut io), cx) {
Poll::Pending => {
self.inner = Some(io);
Poll::Pending
},
}
Poll::Ready(Ok(())) => Poll::Ready(Ok(io)),
Poll::Ready(Err(err)) => {
self.inner = Some(io);
@ -75,7 +87,9 @@ where
impl<TInner> Negotiated<TInner> {
/// Creates a `Negotiated` in state [`State::Completed`].
pub(crate) fn completed(io: TInner) -> Self {
Negotiated { state: State::Completed { io } }
Negotiated {
state: State::Completed { io },
}
}
/// Creates a `Negotiated` in state [`State::Expecting`] that is still
@ -83,25 +97,31 @@ impl<TInner> Negotiated<TInner> {
pub(crate) fn expecting(
io: MessageReader<TInner>,
protocol: Protocol,
header: Option<HeaderLine>
header: Option<HeaderLine>,
) -> Self {
Negotiated { state: State::Expecting { io, protocol, header } }
Negotiated {
state: State::Expecting {
io,
protocol,
header,
},
}
}
/// Polls the `Negotiated` for completion.
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), NegotiationError>>
where
TInner: AsyncRead + AsyncWrite + Unpin
TInner: AsyncRead + AsyncWrite + Unpin,
{
// Flush any pending negotiation data.
match self.as_mut().poll_flush(cx) {
Poll::Ready(Ok(())) => {},
Poll::Ready(Ok(())) => {}
Poll::Pending => return Poll::Pending,
Poll::Ready(Err(e)) => {
// If the remote closed the stream, it is important to still
// continue reading the data that was sent, if any.
if e.kind() != io::ErrorKind::WriteZero {
return Poll::Ready(Err(e.into()))
return Poll::Ready(Err(e.into()));
}
}
}
@ -109,36 +129,52 @@ impl<TInner> Negotiated<TInner> {
let mut this = self.project();
if let StateProj::Completed { .. } = this.state.as_mut().project() {
return Poll::Ready(Ok(()));
return Poll::Ready(Ok(()));
}
// Read outstanding protocol negotiation messages.
loop {
match mem::replace(&mut *this.state, State::Invalid) {
State::Expecting { mut io, header, protocol } => {
State::Expecting {
mut io,
header,
protocol,
} => {
let msg = match Pin::new(&mut io).poll_next(cx)? {
Poll::Ready(Some(msg)) => msg,
Poll::Pending => {
*this.state = State::Expecting { io, header, protocol };
return Poll::Pending
},
*this.state = State::Expecting {
io,
header,
protocol,
};
return Poll::Pending;
}
Poll::Ready(None) => {
return Poll::Ready(Err(ProtocolError::IoError(
io::ErrorKind::UnexpectedEof.into()).into()));
io::ErrorKind::UnexpectedEof.into(),
)
.into()));
}
};
if let Message::Header(h) = &msg {
if Some(h) == header.as_ref() {
*this.state = State::Expecting { io, protocol, header: None };
continue
*this.state = State::Expecting {
io,
protocol,
header: None,
};
continue;
}
}
if let Message::Protocol(p) = &msg {
if p.as_ref() == protocol.as_ref() {
log::debug!("Negotiated: Received confirmation for protocol: {}", p);
*this.state = State::Completed { io: io.into_inner() };
*this.state = State::Completed {
io: io.into_inner(),
};
return Poll::Ready(Ok(()));
}
}
@ -146,7 +182,7 @@ impl<TInner> Negotiated<TInner> {
return Poll::Ready(Err(NegotiationError::Failed));
}
_ => panic!("Negotiated: Invalid state")
_ => panic!("Negotiated: Invalid state"),
}
}
}
@ -178,7 +214,10 @@ enum State<R> {
/// In this state, a protocol has been agreed upon and I/O
/// on the underlying stream can commence.
Completed { #[pin] io: R },
Completed {
#[pin]
io: R,
},
/// Temporary state while moving the `io` resource from
/// `Expecting` to `Completed`.
@ -187,11 +226,13 @@ enum State<R> {
impl<TInner> AsyncRead for Negotiated<TInner>
where
TInner: AsyncRead + AsyncWrite + Unpin
TInner: AsyncRead + AsyncWrite + Unpin,
{
fn poll_read(mut self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut [u8])
-> Poll<Result<usize, io::Error>>
{
fn poll_read(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut [u8],
) -> Poll<Result<usize, io::Error>> {
loop {
if let StateProj::Completed { io } = self.as_mut().project().state.project() {
// If protocol negotiation is complete, commence with reading.
@ -201,7 +242,7 @@ where
// Poll the `Negotiated`, driving protocol negotiation to completion,
// including flushing of any remaining data.
match self.as_mut().poll(cx) {
Poll::Ready(Ok(())) => {},
Poll::Ready(Ok(())) => {}
Poll::Pending => return Poll::Pending,
Poll::Ready(Err(err)) => return Poll::Ready(Err(From::from(err))),
}
@ -217,19 +258,21 @@ where
}
}*/
fn poll_read_vectored(mut self: Pin<&mut Self>, cx: &mut Context<'_>, bufs: &mut [IoSliceMut<'_>])
-> Poll<Result<usize, io::Error>>
{
fn poll_read_vectored(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
bufs: &mut [IoSliceMut<'_>],
) -> Poll<Result<usize, io::Error>> {
loop {
if let StateProj::Completed { io } = self.as_mut().project().state.project() {
// If protocol negotiation is complete, commence with reading.
return io.poll_read_vectored(cx, bufs)
return io.poll_read_vectored(cx, bufs);
}
// Poll the `Negotiated`, driving protocol negotiation to completion,
// including flushing of any remaining data.
match self.as_mut().poll(cx) {
Poll::Ready(Ok(())) => {},
Poll::Ready(Ok(())) => {}
Poll::Pending => return Poll::Pending,
Poll::Ready(Err(err)) => return Poll::Ready(Err(From::from(err))),
}
@ -239,9 +282,13 @@ where
impl<TInner> AsyncWrite for Negotiated<TInner>
where
TInner: AsyncWrite + AsyncRead + Unpin
TInner: AsyncWrite + AsyncRead + Unpin,
{
fn poll_write(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll<Result<usize, io::Error>> {
fn poll_write(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<Result<usize, io::Error>> {
match self.project().state.project() {
StateProj::Completed { io } => io.poll_write(cx, buf),
StateProj::Expecting { io, .. } => io.poll_write(cx, buf),
@ -261,7 +308,10 @@ where
// Ensure all data has been flushed and expected negotiation messages
// have been received.
ready!(self.as_mut().poll(cx).map_err(Into::<io::Error>::into)?);
ready!(self.as_mut().poll_flush(cx).map_err(Into::<io::Error>::into)?);
ready!(self
.as_mut()
.poll_flush(cx)
.map_err(Into::<io::Error>::into)?);
// Continue with the shutdown of the underlying I/O stream.
match self.project().state.project() {
@ -271,9 +321,11 @@ where
}
}
fn poll_write_vectored(self: Pin<&mut Self>, cx: &mut Context<'_>, bufs: &[IoSlice<'_>])
-> Poll<Result<usize, io::Error>>
{
fn poll_write_vectored(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
bufs: &[IoSlice<'_>],
) -> Poll<Result<usize, io::Error>> {
match self.project().state.project() {
StateProj::Completed { io } => io.poll_write_vectored(cx, bufs),
StateProj::Expecting { io, .. } => io.poll_write_vectored(cx, bufs),
@ -307,7 +359,7 @@ impl From<io::Error> for NegotiationError {
impl From<NegotiationError> for io::Error {
fn from(err: NegotiationError) -> io::Error {
if let NegotiationError::ProtocolError(e) = err {
return e.into()
return e.into();
}
io::Error::new(io::ErrorKind::Other, err)
}
@ -325,10 +377,10 @@ impl Error for NegotiationError {
impl fmt::Display for NegotiationError {
fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> {
match self {
NegotiationError::ProtocolError(p) =>
fmt.write_fmt(format_args!("Protocol error: {}", p)),
NegotiationError::Failed =>
fmt.write_str("Protocol negotiation failed.")
NegotiationError::ProtocolError(p) => {
fmt.write_fmt(format_args!("Protocol error: {}", p))
}
NegotiationError::Failed => fmt.write_str("Protocol negotiation failed."),
}
}
}

View File

@ -25,12 +25,18 @@
//! `Stream` and `Sink` implementations of `MessageIO` and
//! `MessageReader`.
use crate::Version;
use crate::length_delimited::{LengthDelimited, LengthDelimitedReader};
use crate::Version;
use bytes::{Bytes, BytesMut, BufMut};
use futures::{prelude::*, io::IoSlice, ready};
use std::{convert::TryFrom, io, fmt, error::Error, pin::Pin, task::{Context, Poll}};
use bytes::{BufMut, Bytes, BytesMut};
use futures::{io::IoSlice, prelude::*, ready};
use std::{
convert::TryFrom,
error::Error,
fmt, io,
pin::Pin,
task::{Context, Poll},
};
use unsigned_varint as uvi;
/// The maximum number of supported protocols that can be processed.
@ -75,7 +81,7 @@ impl TryFrom<Bytes> for Protocol {
fn try_from(value: Bytes) -> Result<Self, Self::Error> {
if !value.as_ref().starts_with(b"/") {
return Err(ProtocolError::InvalidProtocol)
return Err(ProtocolError::InvalidProtocol);
}
Ok(Protocol(value))
}
@ -160,7 +166,7 @@ impl Message {
/// Decodes a `Message` from its byte representation.
pub fn decode(mut msg: Bytes) -> Result<Message, ProtocolError> {
if msg == MSG_MULTISTREAM_1_0 {
return Ok(Message::Header(HeaderLine::V1))
return Ok(Message::Header(HeaderLine::V1));
}
if msg == MSG_PROTOCOL_NA {
@ -168,13 +174,14 @@ impl Message {
}
if msg == MSG_LS {
return Ok(Message::ListProtocols)
return Ok(Message::ListProtocols);
}
// If it starts with a `/`, ends with a line feed without any
// other line feeds in-between, it must be a protocol name.
if msg.get(0) == Some(&b'/') && msg.last() == Some(&b'\n') &&
!msg[.. msg.len() - 1].contains(&b'\n')
if msg.get(0) == Some(&b'/')
&& msg.last() == Some(&b'\n')
&& !msg[..msg.len() - 1].contains(&b'\n')
{
let p = Protocol::try_from(msg.split_to(msg.len() - 1))?;
return Ok(Message::Protocol(p));
@ -187,24 +194,24 @@ impl Message {
loop {
// A well-formed message must be terminated with a newline.
if remaining == [b'\n'] {
break
break;
} else if protocols.len() == MAX_PROTOCOLS {
return Err(ProtocolError::TooManyProtocols)
return Err(ProtocolError::TooManyProtocols);
}
// Decode the length of the next protocol name and check that
// it ends with a line feed.
let (len, tail) = uvi::decode::usize(remaining)?;
if len == 0 || len > tail.len() || tail[len - 1] != b'\n' {
return Err(ProtocolError::InvalidMessage)
return Err(ProtocolError::InvalidMessage);
}
// Parse the protocol name.
let p = Protocol::try_from(Bytes::copy_from_slice(&tail[.. len - 1]))?;
let p = Protocol::try_from(Bytes::copy_from_slice(&tail[..len - 1]))?;
protocols.push(p);
// Skip ahead to the next protocol.
remaining = &tail[len ..];
remaining = &tail[len..];
}
Ok(Message::Protocols(protocols))
@ -222,9 +229,11 @@ impl<R> MessageIO<R> {
/// Constructs a new `MessageIO` resource wrapping the given I/O stream.
pub fn new(inner: R) -> MessageIO<R>
where
R: AsyncRead + AsyncWrite
R: AsyncRead + AsyncWrite,
{
Self { inner: LengthDelimited::new(inner) }
Self {
inner: LengthDelimited::new(inner),
}
}
/// Converts the [`MessageIO`] into a [`MessageReader`], dropping the
@ -235,7 +244,9 @@ impl<R> MessageIO<R> {
/// received but no more messages are written, allowing the writing of
/// follow-up protocol data to commence.
pub fn into_reader(self) -> MessageReader<R> {
MessageReader { inner: self.inner.into_reader() }
MessageReader {
inner: self.inner.into_reader(),
}
}
/// Drops the [`MessageIO`] resource, yielding the underlying I/O stream.
@ -265,7 +276,10 @@ where
fn start_send(self: Pin<&mut Self>, item: Message) -> Result<(), Self::Error> {
let mut buf = BytesMut::new();
item.encode(&mut buf)?;
self.project().inner.start_send(buf.freeze()).map_err(From::from)
self.project()
.inner
.start_send(buf.freeze())
.map_err(From::from)
}
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
@ -279,7 +293,7 @@ where
impl<R> Stream for MessageIO<R>
where
R: AsyncRead
R: AsyncRead,
{
type Item = Result<Message, ProtocolError>;
@ -299,7 +313,7 @@ where
#[derive(Debug)]
pub struct MessageReader<R> {
#[pin]
inner: LengthDelimitedReader<R>
inner: LengthDelimitedReader<R>,
}
impl<R> MessageReader<R> {
@ -321,7 +335,7 @@ impl<R> MessageReader<R> {
impl<R> Stream for MessageReader<R>
where
R: AsyncRead
R: AsyncRead,
{
type Item = Result<Message, ProtocolError>;
@ -332,9 +346,13 @@ where
impl<TInner> AsyncWrite for MessageReader<TInner>
where
TInner: AsyncWrite
TInner: AsyncWrite,
{
fn poll_write(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll<Result<usize, io::Error>> {
fn poll_write(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<Result<usize, io::Error>> {
self.project().inner.poll_write(cx, buf)
}
@ -346,12 +364,19 @@ where
self.project().inner.poll_close(cx)
}
fn poll_write_vectored(self: Pin<&mut Self>, cx: &mut Context<'_>, bufs: &[IoSlice<'_>]) -> Poll<Result<usize, io::Error>> {
fn poll_write_vectored(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
bufs: &[IoSlice<'_>],
) -> Poll<Result<usize, io::Error>> {
self.project().inner.poll_write_vectored(cx, bufs)
}
}
fn poll_stream<S>(stream: Pin<&mut S>, cx: &mut Context<'_>) -> Poll<Option<Result<Message, ProtocolError>>>
fn poll_stream<S>(
stream: Pin<&mut S>,
cx: &mut Context<'_>,
) -> Poll<Option<Result<Message, ProtocolError>>>
where
S: Stream<Item = Result<Bytes, io::Error>>,
{
@ -361,7 +386,7 @@ where
Err(err) => return Poll::Ready(Some(Err(err))),
}
} else {
return Poll::Ready(None)
return Poll::Ready(None);
};
log::trace!("Received message: {:?}", msg);
@ -394,7 +419,7 @@ impl From<io::Error> for ProtocolError {
impl From<ProtocolError> for io::Error {
fn from(err: ProtocolError) -> Self {
if let ProtocolError::IoError(e) = err {
return e
return e;
}
io::ErrorKind::InvalidData.into()
}
@ -418,14 +443,10 @@ impl Error for ProtocolError {
impl fmt::Display for ProtocolError {
fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> {
match self {
ProtocolError::IoError(e) =>
write!(fmt, "I/O error: {}", e),
ProtocolError::InvalidMessage =>
write!(fmt, "Received an invalid message."),
ProtocolError::InvalidProtocol =>
write!(fmt, "A protocol (name) is invalid."),
ProtocolError::TooManyProtocols =>
write!(fmt, "Too many protocols received.")
ProtocolError::IoError(e) => write!(fmt, "I/O error: {}", e),
ProtocolError::InvalidMessage => write!(fmt, "Received an invalid message."),
ProtocolError::InvalidProtocol => write!(fmt, "A protocol (name) is invalid."),
ProtocolError::TooManyProtocols => write!(fmt, "Too many protocols received."),
}
}
}
@ -434,8 +455,8 @@ impl fmt::Display for ProtocolError {
mod tests {
use super::*;
use quickcheck::*;
use rand::Rng;
use rand::distributions::Alphanumeric;
use rand::Rng;
use std::iter;
impl Arbitrary for Protocol {
@ -457,7 +478,7 @@ mod tests {
2 => Message::ListProtocols,
3 => Message::Protocol(Protocol::arbitrary(g)),
4 => Message::Protocols(Vec::arbitrary(g)),
_ => panic!()
_ => panic!(),
}
}
}
@ -466,10 +487,11 @@ mod tests {
fn encode_decode_message() {
fn prop(msg: Message) {
let mut buf = BytesMut::new();
msg.encode(&mut buf).expect(&format!("Encoding message failed: {:?}", msg));
msg.encode(&mut buf)
.expect(&format!("Encoding message failed: {:?}", msg));
match Message::decode(buf.freeze()) {
Ok(m) => assert_eq!(m, msg),
Err(e) => panic!("Decoding failed: {:?}", e)
Err(e) => panic!("Decoding failed: {:?}", e),
}
}
quickcheck(prop as fn(_))

View File

@ -22,9 +22,9 @@
#![cfg(test)]
use crate::{Version, NegotiationError};
use crate::dialer_select::{dialer_select_proto_parallel, dialer_select_proto_serial};
use crate::{dialer_select_proto, listener_select_proto};
use crate::{NegotiationError, Version};
use async_std::net::{TcpListener, TcpStream};
use futures::prelude::*;
@ -54,7 +54,8 @@ fn select_proto_basic() {
let connec = TcpStream::connect(&listener_addr).await.unwrap();
let protos = vec![b"/proto3", b"/proto2"];
let (proto, mut io) = dialer_select_proto(connec, protos.into_iter(), version)
.await.unwrap();
.await
.unwrap();
assert_eq!(proto, b"/proto2");
io.write_all(b"ping").await.unwrap();
@ -79,12 +80,14 @@ fn select_proto_basic() {
fn negotiation_failed() {
let _ = env_logger::try_init();
async fn run(Test {
version,
listen_protos,
dial_protos,
dial_payload
}: Test) {
async fn run(
Test {
version,
listen_protos,
dial_protos,
dial_payload,
}: Test,
) {
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
let listener_addr = listener.local_addr().unwrap();
@ -93,10 +96,12 @@ fn negotiation_failed() {
let io = match listener_select_proto(connec, listen_protos).await {
Ok((_, io)) => io,
Err(NegotiationError::Failed) => return,
Err(NegotiationError::ProtocolError(e)) => panic!("Unexpected protocol error {}", e),
Err(NegotiationError::ProtocolError(e)) => {
panic!("Unexpected protocol error {}", e)
}
};
match io.complete().await {
Err(NegotiationError::Failed) => {},
Err(NegotiationError::Failed) => {}
_ => panic!(),
}
});
@ -106,14 +111,14 @@ fn negotiation_failed() {
let mut io = match dialer_select_proto(connec, dial_protos.into_iter(), version).await {
Err(NegotiationError::Failed) => return,
Ok((_, io)) => io,
Err(_) => panic!()
Err(_) => panic!(),
};
// The dialer may write a payload that is even sent before it
// got confirmation of the last proposed protocol, when `V1Lazy`
// is used.
io.write_all(&dial_payload).await.unwrap();
match io.complete().await {
Err(NegotiationError::Failed) => {},
Err(NegotiationError::Failed) => {}
_ => panic!(),
}
});
@ -135,10 +140,10 @@ fn negotiation_failed() {
//
// The choices here cover the main distinction between a single
// and multiple protocols.
let protos = vec!{
let protos = vec![
(vec!["/proto1"], vec!["/proto2"]),
(vec!["/proto1", "/proto2"], vec!["/proto3", "/proto4"]),
};
];
// The payloads that the dialer sends after "successful" negotiation,
// which may be sent even before the dialer got protocol confirmation
@ -147,7 +152,7 @@ fn negotiation_failed() {
// The choices here cover the specific situations that can arise with
// `V1Lazy` and which must nevertheless behave identically to `V1` w.r.t.
// the outcome of the negotiation.
let payloads = vec!{
let payloads = vec![
// No payload, in which case all versions should behave identically
// in any case, i.e. the baseline test.
vec![],
@ -155,13 +160,13 @@ fn negotiation_failed() {
// `1` as a message length and encounters an invalid message (the
// second `1`). The listener is nevertheless expected to fail
// negotiation normally, just like with `V1`.
vec![1,1],
vec![1, 1],
// With this payload and `V1Lazy`, the listener interprets the first
// `42` as a message length and encounters unexpected EOF trying to
// read a message of that length. The listener is nevertheless expected
// to fail negotiation normally, just like with `V1`
vec![42,1],
};
vec![42, 1],
];
for (listen_protos, dial_protos) in protos {
for dial_payload in payloads.clone() {
@ -195,7 +200,8 @@ fn select_proto_parallel() {
let connec = TcpStream::connect(&listener_addr).await.unwrap();
let protos = vec![b"/proto3", b"/proto2"];
let (proto, io) = dialer_select_proto_parallel(connec, protos.into_iter(), version)
.await.unwrap();
.await
.unwrap();
assert_eq!(proto, b"/proto2");
io.complete().await.unwrap();
});
@ -226,7 +232,8 @@ fn select_proto_serial() {
let connec = TcpStream::connect(&listener_addr).await.unwrap();
let protos = vec![b"/proto3", b"/proto2"];
let (proto, io) = dialer_select_proto_serial(connec, protos.into_iter(), version)
.await.unwrap();
.await
.unwrap();
assert_eq!(proto, b"/proto2");
io.complete().await.unwrap();
});