mirror of
https://github.com/fluencelabs/rust-libp2p
synced 2025-06-26 16:21:39 +00:00
*: Format with rustfmt (#2188)
Co-authored-by: Thomas Eizinger <thomas@eizinger.io>
This commit is contained in:
@ -20,11 +20,16 @@
|
||||
|
||||
//! Protocol negotiation strategies for the peer acting as the dialer.
|
||||
|
||||
use crate::protocol::{HeaderLine, Message, MessageIO, Protocol, ProtocolError};
|
||||
use crate::{Negotiated, NegotiationError, Version};
|
||||
use crate::protocol::{Protocol, ProtocolError, MessageIO, Message, HeaderLine};
|
||||
|
||||
use futures::{future::Either, prelude::*};
|
||||
use std::{convert::TryFrom as _, iter, mem, pin::Pin, task::{Context, Poll}};
|
||||
use std::{
|
||||
convert::TryFrom as _,
|
||||
iter, mem,
|
||||
pin::Pin,
|
||||
task::{Context, Poll},
|
||||
};
|
||||
|
||||
/// Returns a `Future` that negotiates a protocol on the given I/O stream
|
||||
/// for a peer acting as the _dialer_ (or _initiator_).
|
||||
@ -48,17 +53,17 @@ use std::{convert::TryFrom as _, iter, mem, pin::Pin, task::{Context, Poll}};
|
||||
pub fn dialer_select_proto<R, I>(
|
||||
inner: R,
|
||||
protocols: I,
|
||||
version: Version
|
||||
version: Version,
|
||||
) -> DialerSelectFuture<R, I::IntoIter>
|
||||
where
|
||||
R: AsyncRead + AsyncWrite,
|
||||
I: IntoIterator,
|
||||
I::Item: AsRef<[u8]>
|
||||
I::Item: AsRef<[u8]>,
|
||||
{
|
||||
let iter = protocols.into_iter();
|
||||
// We choose between the "serial" and "parallel" strategies based on the number of protocols.
|
||||
if iter.size_hint().1.map(|n| n <= 3).unwrap_or(false) {
|
||||
Either::Left(dialer_select_proto_serial(inner, iter, version))
|
||||
Either::Left(dialer_select_proto_serial(inner, iter, version))
|
||||
} else {
|
||||
Either::Right(dialer_select_proto_parallel(inner, iter, version))
|
||||
}
|
||||
@ -79,12 +84,12 @@ pub type DialerSelectFuture<R, I> = Either<DialerSelectSeq<R, I>, DialerSelectPa
|
||||
pub(crate) fn dialer_select_proto_serial<R, I>(
|
||||
inner: R,
|
||||
protocols: I,
|
||||
version: Version
|
||||
version: Version,
|
||||
) -> DialerSelectSeq<R, I::IntoIter>
|
||||
where
|
||||
R: AsyncRead + AsyncWrite,
|
||||
I: IntoIterator,
|
||||
I::Item: AsRef<[u8]>
|
||||
I::Item: AsRef<[u8]>,
|
||||
{
|
||||
let protocols = protocols.into_iter().peekable();
|
||||
DialerSelectSeq {
|
||||
@ -92,7 +97,7 @@ where
|
||||
protocols,
|
||||
state: SeqState::SendHeader {
|
||||
io: MessageIO::new(inner),
|
||||
}
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
@ -108,20 +113,20 @@ where
|
||||
pub(crate) fn dialer_select_proto_parallel<R, I>(
|
||||
inner: R,
|
||||
protocols: I,
|
||||
version: Version
|
||||
version: Version,
|
||||
) -> DialerSelectPar<R, I::IntoIter>
|
||||
where
|
||||
R: AsyncRead + AsyncWrite,
|
||||
I: IntoIterator,
|
||||
I::Item: AsRef<[u8]>
|
||||
I::Item: AsRef<[u8]>,
|
||||
{
|
||||
let protocols = protocols.into_iter();
|
||||
DialerSelectPar {
|
||||
version,
|
||||
protocols,
|
||||
state: ParState::SendHeader {
|
||||
io: MessageIO::new(inner)
|
||||
}
|
||||
io: MessageIO::new(inner),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
@ -136,11 +141,11 @@ pub struct DialerSelectSeq<R, I: Iterator> {
|
||||
}
|
||||
|
||||
enum SeqState<R, N> {
|
||||
SendHeader { io: MessageIO<R>, },
|
||||
SendHeader { io: MessageIO<R> },
|
||||
SendProtocol { io: MessageIO<R>, protocol: N },
|
||||
FlushProtocol { io: MessageIO<R>, protocol: N },
|
||||
AwaitProtocol { io: MessageIO<R>, protocol: N },
|
||||
Done
|
||||
Done,
|
||||
}
|
||||
|
||||
impl<R, I> Future for DialerSelectSeq<R, I>
|
||||
@ -149,7 +154,7 @@ where
|
||||
// It also makes the implementation considerably easier to write.
|
||||
R: AsyncRead + AsyncWrite + Unpin,
|
||||
I: Iterator,
|
||||
I::Item: AsRef<[u8]>
|
||||
I::Item: AsRef<[u8]>,
|
||||
{
|
||||
type Output = Result<(I::Item, Negotiated<R>), NegotiationError>;
|
||||
|
||||
@ -160,11 +165,11 @@ where
|
||||
match mem::replace(this.state, SeqState::Done) {
|
||||
SeqState::SendHeader { mut io } => {
|
||||
match Pin::new(&mut io).poll_ready(cx)? {
|
||||
Poll::Ready(()) => {},
|
||||
Poll::Ready(()) => {}
|
||||
Poll::Pending => {
|
||||
*this.state = SeqState::SendHeader { io };
|
||||
return Poll::Pending
|
||||
},
|
||||
return Poll::Pending;
|
||||
}
|
||||
}
|
||||
|
||||
let h = HeaderLine::from(*this.version);
|
||||
@ -181,11 +186,11 @@ where
|
||||
|
||||
SeqState::SendProtocol { mut io, protocol } => {
|
||||
match Pin::new(&mut io).poll_ready(cx)? {
|
||||
Poll::Ready(()) => {},
|
||||
Poll::Ready(()) => {}
|
||||
Poll::Pending => {
|
||||
*this.state = SeqState::SendProtocol { io, protocol };
|
||||
return Poll::Pending
|
||||
},
|
||||
return Poll::Pending;
|
||||
}
|
||||
}
|
||||
|
||||
let p = Protocol::try_from(protocol.as_ref())?;
|
||||
@ -207,7 +212,7 @@ where
|
||||
log::debug!("Dialer: Expecting proposed protocol: {}", p);
|
||||
let hl = HeaderLine::from(Version::V1Lazy);
|
||||
let io = Negotiated::expecting(io.into_reader(), p, Some(hl));
|
||||
return Poll::Ready(Ok((protocol, io)))
|
||||
return Poll::Ready(Ok((protocol, io)));
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -218,8 +223,8 @@ where
|
||||
Poll::Ready(()) => *this.state = SeqState::AwaitProtocol { io, protocol },
|
||||
Poll::Pending => {
|
||||
*this.state = SeqState::FlushProtocol { io, protocol };
|
||||
return Poll::Pending
|
||||
},
|
||||
return Poll::Pending;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -228,7 +233,7 @@ where
|
||||
Poll::Ready(Some(msg)) => msg,
|
||||
Poll::Pending => {
|
||||
*this.state = SeqState::AwaitProtocol { io, protocol };
|
||||
return Poll::Pending
|
||||
return Poll::Pending;
|
||||
}
|
||||
// Treat EOF error as [`NegotiationError::Failed`], not as
|
||||
// [`NegotiationError::ProtocolError`], allowing dropping or closing an I/O
|
||||
@ -246,16 +251,18 @@ where
|
||||
return Poll::Ready(Ok((protocol, io)));
|
||||
}
|
||||
Message::NotAvailable => {
|
||||
log::debug!("Dialer: Received rejection of protocol: {}",
|
||||
String::from_utf8_lossy(protocol.as_ref()));
|
||||
log::debug!(
|
||||
"Dialer: Received rejection of protocol: {}",
|
||||
String::from_utf8_lossy(protocol.as_ref())
|
||||
);
|
||||
let protocol = this.protocols.next().ok_or(NegotiationError::Failed)?;
|
||||
*this.state = SeqState::SendProtocol { io, protocol }
|
||||
}
|
||||
_ => return Poll::Ready(Err(ProtocolError::InvalidMessage.into()))
|
||||
_ => return Poll::Ready(Err(ProtocolError::InvalidMessage.into())),
|
||||
}
|
||||
}
|
||||
|
||||
SeqState::Done => panic!("SeqState::poll called after completion")
|
||||
SeqState::Done => panic!("SeqState::poll called after completion"),
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -277,7 +284,7 @@ enum ParState<R, N> {
|
||||
Flush { io: MessageIO<R> },
|
||||
RecvProtocols { io: MessageIO<R> },
|
||||
SendProtocol { io: MessageIO<R>, protocol: N },
|
||||
Done
|
||||
Done,
|
||||
}
|
||||
|
||||
impl<R, I> Future for DialerSelectPar<R, I>
|
||||
@ -286,7 +293,7 @@ where
|
||||
// It also makes the implementation considerably easier to write.
|
||||
R: AsyncRead + AsyncWrite + Unpin,
|
||||
I: Iterator,
|
||||
I::Item: AsRef<[u8]>
|
||||
I::Item: AsRef<[u8]>,
|
||||
{
|
||||
type Output = Result<(I::Item, Negotiated<R>), NegotiationError>;
|
||||
|
||||
@ -297,11 +304,11 @@ where
|
||||
match mem::replace(this.state, ParState::Done) {
|
||||
ParState::SendHeader { mut io } => {
|
||||
match Pin::new(&mut io).poll_ready(cx)? {
|
||||
Poll::Ready(()) => {},
|
||||
Poll::Ready(()) => {}
|
||||
Poll::Pending => {
|
||||
*this.state = ParState::SendHeader { io };
|
||||
return Poll::Pending
|
||||
},
|
||||
return Poll::Pending;
|
||||
}
|
||||
}
|
||||
|
||||
let msg = Message::Header(HeaderLine::from(*this.version));
|
||||
@ -314,11 +321,11 @@ where
|
||||
|
||||
ParState::SendProtocolsRequest { mut io } => {
|
||||
match Pin::new(&mut io).poll_ready(cx)? {
|
||||
Poll::Ready(()) => {},
|
||||
Poll::Ready(()) => {}
|
||||
Poll::Pending => {
|
||||
*this.state = ParState::SendProtocolsRequest { io };
|
||||
return Poll::Pending
|
||||
},
|
||||
return Poll::Pending;
|
||||
}
|
||||
}
|
||||
|
||||
if let Err(err) = Pin::new(&mut io).start_send(Message::ListProtocols) {
|
||||
@ -329,22 +336,20 @@ where
|
||||
*this.state = ParState::Flush { io }
|
||||
}
|
||||
|
||||
ParState::Flush { mut io } => {
|
||||
match Pin::new(&mut io).poll_flush(cx)? {
|
||||
Poll::Ready(()) => *this.state = ParState::RecvProtocols { io },
|
||||
Poll::Pending => {
|
||||
*this.state = ParState::Flush { io };
|
||||
return Poll::Pending
|
||||
},
|
||||
ParState::Flush { mut io } => match Pin::new(&mut io).poll_flush(cx)? {
|
||||
Poll::Ready(()) => *this.state = ParState::RecvProtocols { io },
|
||||
Poll::Pending => {
|
||||
*this.state = ParState::Flush { io };
|
||||
return Poll::Pending;
|
||||
}
|
||||
}
|
||||
},
|
||||
|
||||
ParState::RecvProtocols { mut io } => {
|
||||
let msg = match Pin::new(&mut io).poll_next(cx)? {
|
||||
Poll::Ready(Some(msg)) => msg,
|
||||
Poll::Pending => {
|
||||
*this.state = ParState::RecvProtocols { io };
|
||||
return Poll::Pending
|
||||
return Poll::Pending;
|
||||
}
|
||||
// Treat EOF error as [`NegotiationError::Failed`], not as
|
||||
// [`NegotiationError::ProtocolError`], allowing dropping or closing an I/O
|
||||
@ -357,12 +362,15 @@ where
|
||||
*this.state = ParState::RecvProtocols { io }
|
||||
}
|
||||
Message::Protocols(supported) => {
|
||||
let protocol = this.protocols.by_ref()
|
||||
.find(|p| supported.iter().any(|s|
|
||||
s.as_ref() == p.as_ref()))
|
||||
let protocol = this
|
||||
.protocols
|
||||
.by_ref()
|
||||
.find(|p| supported.iter().any(|s| s.as_ref() == p.as_ref()))
|
||||
.ok_or(NegotiationError::Failed)?;
|
||||
log::debug!("Dialer: Found supported protocol: {}",
|
||||
String::from_utf8_lossy(protocol.as_ref()));
|
||||
log::debug!(
|
||||
"Dialer: Found supported protocol: {}",
|
||||
String::from_utf8_lossy(protocol.as_ref())
|
||||
);
|
||||
*this.state = ParState::SendProtocol { io, protocol };
|
||||
}
|
||||
_ => return Poll::Ready(Err(ProtocolError::InvalidMessage.into())),
|
||||
@ -371,11 +379,11 @@ where
|
||||
|
||||
ParState::SendProtocol { mut io, protocol } => {
|
||||
match Pin::new(&mut io).poll_ready(cx)? {
|
||||
Poll::Ready(()) => {},
|
||||
Poll::Ready(()) => {}
|
||||
Poll::Pending => {
|
||||
*this.state = ParState::SendProtocol { io, protocol };
|
||||
return Poll::Pending
|
||||
},
|
||||
return Poll::Pending;
|
||||
}
|
||||
}
|
||||
|
||||
let p = Protocol::try_from(protocol.as_ref())?;
|
||||
@ -386,10 +394,10 @@ where
|
||||
log::debug!("Dialer: Expecting proposed protocol: {}", p);
|
||||
let io = Negotiated::expecting(io.into_reader(), p, None);
|
||||
|
||||
return Poll::Ready(Ok((protocol, io)))
|
||||
return Poll::Ready(Ok((protocol, io)));
|
||||
}
|
||||
|
||||
ParState::Done => panic!("ParState::poll called after completion")
|
||||
ParState::Done => panic!("ParState::poll called after completion"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -18,9 +18,15 @@
|
||||
// FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
|
||||
// DEALINGS IN THE SOFTWARE.
|
||||
|
||||
use bytes::{Bytes, BytesMut, Buf as _, BufMut as _};
|
||||
use futures::{prelude::*, io::IoSlice};
|
||||
use std::{convert::TryFrom as _, io, pin::Pin, task::{Poll, Context}, u16};
|
||||
use bytes::{Buf as _, BufMut as _, Bytes, BytesMut};
|
||||
use futures::{io::IoSlice, prelude::*};
|
||||
use std::{
|
||||
convert::TryFrom as _,
|
||||
io,
|
||||
pin::Pin,
|
||||
task::{Context, Poll},
|
||||
u16,
|
||||
};
|
||||
|
||||
const MAX_LEN_BYTES: u16 = 2;
|
||||
const MAX_FRAME_SIZE: u16 = (1 << (MAX_LEN_BYTES * 8 - MAX_LEN_BYTES)) - 1;
|
||||
@ -50,7 +56,10 @@ pub struct LengthDelimited<R> {
|
||||
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
|
||||
enum ReadState {
|
||||
/// We are currently reading the length of the next frame of data.
|
||||
ReadLength { buf: [u8; MAX_LEN_BYTES as usize], pos: usize },
|
||||
ReadLength {
|
||||
buf: [u8; MAX_LEN_BYTES as usize],
|
||||
pos: usize,
|
||||
},
|
||||
/// We are currently reading the frame of data itself.
|
||||
ReadData { len: u16, pos: usize },
|
||||
}
|
||||
@ -59,7 +68,7 @@ impl Default for ReadState {
|
||||
fn default() -> Self {
|
||||
ReadState::ReadLength {
|
||||
buf: [0; MAX_LEN_BYTES as usize],
|
||||
pos: 0
|
||||
pos: 0,
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -106,10 +115,12 @@ impl<R> LengthDelimited<R> {
|
||||
///
|
||||
/// After this method returns `Poll::Ready`, the write buffer of frames
|
||||
/// submitted to the `Sink` is guaranteed to be empty.
|
||||
pub fn poll_write_buffer(self: Pin<&mut Self>, cx: &mut Context<'_>)
|
||||
-> Poll<Result<(), io::Error>>
|
||||
pub fn poll_write_buffer(
|
||||
self: Pin<&mut Self>,
|
||||
cx: &mut Context<'_>,
|
||||
) -> Poll<Result<(), io::Error>>
|
||||
where
|
||||
R: AsyncWrite
|
||||
R: AsyncWrite,
|
||||
{
|
||||
let mut this = self.project();
|
||||
|
||||
@ -119,7 +130,8 @@ impl<R> LengthDelimited<R> {
|
||||
Poll::Ready(Ok(0)) => {
|
||||
return Poll::Ready(Err(io::Error::new(
|
||||
io::ErrorKind::WriteZero,
|
||||
"Failed to write buffered frame.")))
|
||||
"Failed to write buffered frame.",
|
||||
)))
|
||||
}
|
||||
Poll::Ready(Ok(n)) => this.write_buffer.advance(n),
|
||||
Poll::Ready(Err(err)) => return Poll::Ready(Err(err)),
|
||||
@ -132,7 +144,7 @@ impl<R> LengthDelimited<R> {
|
||||
|
||||
impl<R> Stream for LengthDelimited<R>
|
||||
where
|
||||
R: AsyncRead
|
||||
R: AsyncRead,
|
||||
{
|
||||
type Item = Result<Bytes, io::Error>;
|
||||
|
||||
@ -142,7 +154,7 @@ where
|
||||
loop {
|
||||
match this.read_state {
|
||||
ReadState::ReadLength { buf, pos } => {
|
||||
match this.inner.as_mut().poll_read(cx, &mut buf[*pos .. *pos + 1]) {
|
||||
match this.inner.as_mut().poll_read(cx, &mut buf[*pos..*pos + 1]) {
|
||||
Poll::Ready(Ok(0)) => {
|
||||
if *pos == 0 {
|
||||
return Poll::Ready(None);
|
||||
@ -160,11 +172,10 @@ where
|
||||
|
||||
if (buf[*pos - 1] & 0x80) == 0 {
|
||||
// MSB is not set, indicating the end of the length prefix.
|
||||
let (len, _) = unsigned_varint::decode::u16(buf)
|
||||
.map_err(|e| {
|
||||
log::debug!("invalid length prefix: {}", e);
|
||||
io::Error::new(io::ErrorKind::InvalidData, "invalid length prefix")
|
||||
})?;
|
||||
let (len, _) = unsigned_varint::decode::u16(buf).map_err(|e| {
|
||||
log::debug!("invalid length prefix: {}", e);
|
||||
io::Error::new(io::ErrorKind::InvalidData, "invalid length prefix")
|
||||
})?;
|
||||
|
||||
if len >= 1 {
|
||||
*this.read_state = ReadState::ReadData { len, pos: 0 };
|
||||
@ -179,12 +190,19 @@ where
|
||||
// See the module documentation about the max frame len.
|
||||
return Poll::Ready(Some(Err(io::Error::new(
|
||||
io::ErrorKind::InvalidData,
|
||||
"Maximum frame length exceeded"))));
|
||||
"Maximum frame length exceeded",
|
||||
))));
|
||||
}
|
||||
}
|
||||
ReadState::ReadData { len, pos } => {
|
||||
match this.inner.as_mut().poll_read(cx, &mut this.read_buffer[*pos..]) {
|
||||
Poll::Ready(Ok(0)) => return Poll::Ready(Some(Err(io::ErrorKind::UnexpectedEof.into()))),
|
||||
match this
|
||||
.inner
|
||||
.as_mut()
|
||||
.poll_read(cx, &mut this.read_buffer[*pos..])
|
||||
{
|
||||
Poll::Ready(Ok(0)) => {
|
||||
return Poll::Ready(Some(Err(io::ErrorKind::UnexpectedEof.into())))
|
||||
}
|
||||
Poll::Ready(Ok(n)) => *pos += n,
|
||||
Poll::Pending => return Poll::Pending,
|
||||
Poll::Ready(Err(err)) => return Poll::Ready(Some(Err(err))),
|
||||
@ -214,7 +232,7 @@ where
|
||||
// implied to be roughly 2 * MAX_FRAME_SIZE.
|
||||
if self.as_mut().project().write_buffer.len() >= MAX_FRAME_SIZE as usize {
|
||||
match self.as_mut().poll_write_buffer(cx) {
|
||||
Poll::Ready(Ok(())) => {},
|
||||
Poll::Ready(Ok(())) => {}
|
||||
Poll::Ready(Err(err)) => return Poll::Ready(Err(err)),
|
||||
Poll::Pending => return Poll::Pending,
|
||||
}
|
||||
@ -233,7 +251,8 @@ where
|
||||
_ => {
|
||||
return Err(io::Error::new(
|
||||
io::ErrorKind::InvalidData,
|
||||
"Maximum frame size exceeded."))
|
||||
"Maximum frame size exceeded.",
|
||||
))
|
||||
}
|
||||
};
|
||||
|
||||
@ -249,7 +268,7 @@ where
|
||||
fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
|
||||
// Write all buffered frame data to the underlying I/O stream.
|
||||
match LengthDelimited::poll_write_buffer(self.as_mut(), cx) {
|
||||
Poll::Ready(Ok(())) => {},
|
||||
Poll::Ready(Ok(())) => {}
|
||||
Poll::Ready(Err(err)) => return Poll::Ready(Err(err)),
|
||||
Poll::Pending => return Poll::Pending,
|
||||
}
|
||||
@ -264,7 +283,7 @@ where
|
||||
fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
|
||||
// Write all buffered frame data to the underlying I/O stream.
|
||||
match LengthDelimited::poll_write_buffer(self.as_mut(), cx) {
|
||||
Poll::Ready(Ok(())) => {},
|
||||
Poll::Ready(Ok(())) => {}
|
||||
Poll::Ready(Err(err)) => return Poll::Ready(Err(err)),
|
||||
Poll::Pending => return Poll::Pending,
|
||||
}
|
||||
@ -283,7 +302,7 @@ where
|
||||
#[derive(Debug)]
|
||||
pub struct LengthDelimitedReader<R> {
|
||||
#[pin]
|
||||
inner: LengthDelimited<R>
|
||||
inner: LengthDelimited<R>,
|
||||
}
|
||||
|
||||
impl<R> LengthDelimitedReader<R> {
|
||||
@ -306,7 +325,7 @@ impl<R> LengthDelimitedReader<R> {
|
||||
|
||||
impl<R> Stream for LengthDelimitedReader<R>
|
||||
where
|
||||
R: AsyncRead
|
||||
R: AsyncRead,
|
||||
{
|
||||
type Item = Result<Bytes, io::Error>;
|
||||
|
||||
@ -317,17 +336,19 @@ where
|
||||
|
||||
impl<R> AsyncWrite for LengthDelimitedReader<R>
|
||||
where
|
||||
R: AsyncWrite
|
||||
R: AsyncWrite,
|
||||
{
|
||||
fn poll_write(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8])
|
||||
-> Poll<Result<usize, io::Error>>
|
||||
{
|
||||
fn poll_write(
|
||||
self: Pin<&mut Self>,
|
||||
cx: &mut Context<'_>,
|
||||
buf: &[u8],
|
||||
) -> Poll<Result<usize, io::Error>> {
|
||||
// `this` here designates the `LengthDelimited`.
|
||||
let mut this = self.project().inner;
|
||||
|
||||
// We need to flush any data previously written with the `LengthDelimited`.
|
||||
match LengthDelimited::poll_write_buffer(this.as_mut(), cx) {
|
||||
Poll::Ready(Ok(())) => {},
|
||||
Poll::Ready(Ok(())) => {}
|
||||
Poll::Ready(Err(err)) => return Poll::Ready(Err(err)),
|
||||
Poll::Pending => return Poll::Pending,
|
||||
}
|
||||
@ -344,15 +365,17 @@ where
|
||||
self.project().inner.poll_close(cx)
|
||||
}
|
||||
|
||||
fn poll_write_vectored(self: Pin<&mut Self>, cx: &mut Context<'_>, bufs: &[IoSlice<'_>])
|
||||
-> Poll<Result<usize, io::Error>>
|
||||
{
|
||||
fn poll_write_vectored(
|
||||
self: Pin<&mut Self>,
|
||||
cx: &mut Context<'_>,
|
||||
bufs: &[IoSlice<'_>],
|
||||
) -> Poll<Result<usize, io::Error>> {
|
||||
// `this` here designates the `LengthDelimited`.
|
||||
let mut this = self.project().inner;
|
||||
|
||||
// We need to flush any data previously written with the `LengthDelimited`.
|
||||
match LengthDelimited::poll_write_buffer(this.as_mut(), cx) {
|
||||
Poll::Ready(Ok(())) => {},
|
||||
Poll::Ready(Ok(())) => {}
|
||||
Poll::Ready(Err(err)) => return Poll::Ready(Err(err)),
|
||||
Poll::Pending => return Poll::Pending,
|
||||
}
|
||||
@ -366,7 +389,7 @@ where
|
||||
mod tests {
|
||||
use crate::length_delimited::LengthDelimited;
|
||||
use async_std::net::{TcpListener, TcpStream};
|
||||
use futures::{prelude::*, io::Cursor};
|
||||
use futures::{io::Cursor, prelude::*};
|
||||
use quickcheck::*;
|
||||
use std::io::ErrorKind;
|
||||
|
||||
@ -394,9 +417,7 @@ mod tests {
|
||||
let mut data = vec![(len & 0x7f) as u8 | 0x80, (len >> 7) as u8];
|
||||
data.extend(frame.clone().into_iter());
|
||||
let mut framed = LengthDelimited::new(Cursor::new(data));
|
||||
let recved = futures::executor::block_on(async move {
|
||||
framed.next().await
|
||||
}).unwrap();
|
||||
let recved = futures::executor::block_on(async move { framed.next().await }).unwrap();
|
||||
assert_eq!(recved.unwrap(), frame);
|
||||
}
|
||||
|
||||
@ -405,9 +426,7 @@ mod tests {
|
||||
let mut data = vec![0x81, 0x81, 0x1];
|
||||
data.extend((0..16513).map(|_| 0));
|
||||
let mut framed = LengthDelimited::new(Cursor::new(data));
|
||||
let recved = futures::executor::block_on(async move {
|
||||
framed.next().await.unwrap()
|
||||
});
|
||||
let recved = futures::executor::block_on(async move { framed.next().await.unwrap() });
|
||||
|
||||
if let Err(io_err) = recved {
|
||||
assert_eq!(io_err.kind(), ErrorKind::InvalidData)
|
||||
@ -479,7 +498,8 @@ mod tests {
|
||||
let expected_frames = frames.clone();
|
||||
let server = async_std::task::spawn(async move {
|
||||
let socket = listener.accept().await.unwrap().0;
|
||||
let mut connec = rw_stream_sink::RwStreamSink::new(LengthDelimited::new(socket));
|
||||
let mut connec =
|
||||
rw_stream_sink::RwStreamSink::new(LengthDelimited::new(socket));
|
||||
|
||||
let mut buf = vec![0u8; 0];
|
||||
for expected in expected_frames {
|
||||
|
@ -94,10 +94,10 @@ mod negotiated;
|
||||
mod protocol;
|
||||
mod tests;
|
||||
|
||||
pub use self::negotiated::{Negotiated, NegotiatedComplete, NegotiationError};
|
||||
pub use self::protocol::ProtocolError;
|
||||
pub use self::dialer_select::{dialer_select_proto, DialerSelectFuture};
|
||||
pub use self::listener_select::{listener_select_proto, ListenerSelectFuture};
|
||||
pub use self::negotiated::{Negotiated, NegotiatedComplete, NegotiationError};
|
||||
pub use self::protocol::ProtocolError;
|
||||
|
||||
/// Supported multistream-select versions.
|
||||
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
|
||||
@ -145,4 +145,4 @@ impl Default for Version {
|
||||
fn default() -> Self {
|
||||
Version::V1
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -21,12 +21,18 @@
|
||||
//! Protocol negotiation strategies for the peer acting as the listener
|
||||
//! in a multistream-select protocol negotiation.
|
||||
|
||||
use crate::protocol::{HeaderLine, Message, MessageIO, Protocol, ProtocolError};
|
||||
use crate::{Negotiated, NegotiationError};
|
||||
use crate::protocol::{Protocol, ProtocolError, MessageIO, Message, HeaderLine};
|
||||
|
||||
use futures::prelude::*;
|
||||
use smallvec::SmallVec;
|
||||
use std::{convert::TryFrom as _, iter::FromIterator, mem, pin::Pin, task::{Context, Poll}};
|
||||
use std::{
|
||||
convert::TryFrom as _,
|
||||
iter::FromIterator,
|
||||
mem,
|
||||
pin::Pin,
|
||||
task::{Context, Poll},
|
||||
};
|
||||
|
||||
/// Returns a `Future` that negotiates a protocol on the given I/O stream
|
||||
/// for a peer acting as the _listener_ (or _responder_).
|
||||
@ -35,28 +41,29 @@ use std::{convert::TryFrom as _, iter::FromIterator, mem, pin::Pin, task::{Conte
|
||||
/// computation that performs the protocol negotiation with the remote. The
|
||||
/// returned `Future` resolves with the name of the negotiated protocol and
|
||||
/// a [`Negotiated`] I/O stream.
|
||||
pub fn listener_select_proto<R, I>(
|
||||
inner: R,
|
||||
protocols: I,
|
||||
) -> ListenerSelectFuture<R, I::Item>
|
||||
pub fn listener_select_proto<R, I>(inner: R, protocols: I) -> ListenerSelectFuture<R, I::Item>
|
||||
where
|
||||
R: AsyncRead + AsyncWrite,
|
||||
I: IntoIterator,
|
||||
I::Item: AsRef<[u8]>
|
||||
I::Item: AsRef<[u8]>,
|
||||
{
|
||||
let protocols = protocols.into_iter().filter_map(|n|
|
||||
match Protocol::try_from(n.as_ref()) {
|
||||
let protocols = protocols
|
||||
.into_iter()
|
||||
.filter_map(|n| match Protocol::try_from(n.as_ref()) {
|
||||
Ok(p) => Some((n, p)),
|
||||
Err(e) => {
|
||||
log::warn!("Listener: Ignoring invalid protocol: {} due to {}",
|
||||
String::from_utf8_lossy(n.as_ref()), e);
|
||||
log::warn!(
|
||||
"Listener: Ignoring invalid protocol: {} due to {}",
|
||||
String::from_utf8_lossy(n.as_ref()),
|
||||
e
|
||||
);
|
||||
None
|
||||
}
|
||||
});
|
||||
ListenerSelectFuture {
|
||||
protocols: SmallVec::from_iter(protocols),
|
||||
state: State::RecvHeader {
|
||||
io: MessageIO::new(inner)
|
||||
io: MessageIO::new(inner),
|
||||
},
|
||||
last_sent_na: false,
|
||||
}
|
||||
@ -80,19 +87,25 @@ pub struct ListenerSelectFuture<R, N> {
|
||||
}
|
||||
|
||||
enum State<R, N> {
|
||||
RecvHeader { io: MessageIO<R> },
|
||||
SendHeader { io: MessageIO<R> },
|
||||
RecvMessage { io: MessageIO<R> },
|
||||
RecvHeader {
|
||||
io: MessageIO<R>,
|
||||
},
|
||||
SendHeader {
|
||||
io: MessageIO<R>,
|
||||
},
|
||||
RecvMessage {
|
||||
io: MessageIO<R>,
|
||||
},
|
||||
SendMessage {
|
||||
io: MessageIO<R>,
|
||||
message: Message,
|
||||
protocol: Option<N>
|
||||
protocol: Option<N>,
|
||||
},
|
||||
Flush {
|
||||
io: MessageIO<R>,
|
||||
protocol: Option<N>
|
||||
protocol: Option<N>,
|
||||
},
|
||||
Done
|
||||
Done,
|
||||
}
|
||||
|
||||
impl<R, N> Future for ListenerSelectFuture<R, N>
|
||||
@ -100,7 +113,7 @@ where
|
||||
// The Unpin bound here is required because we produce a `Negotiated<R>` as the output.
|
||||
// It also makes the implementation considerably easier to write.
|
||||
R: AsyncRead + AsyncWrite + Unpin,
|
||||
N: AsRef<[u8]> + Clone
|
||||
N: AsRef<[u8]> + Clone,
|
||||
{
|
||||
type Output = Result<(N, Negotiated<R>), NegotiationError>;
|
||||
|
||||
@ -111,14 +124,12 @@ where
|
||||
match mem::replace(this.state, State::Done) {
|
||||
State::RecvHeader { mut io } => {
|
||||
match io.poll_next_unpin(cx) {
|
||||
Poll::Ready(Some(Ok(Message::Header(h)))) => {
|
||||
match h {
|
||||
HeaderLine::V1 => *this.state = State::SendHeader { io }
|
||||
}
|
||||
}
|
||||
Poll::Ready(Some(Ok(Message::Header(h)))) => match h {
|
||||
HeaderLine::V1 => *this.state = State::SendHeader { io },
|
||||
},
|
||||
Poll::Ready(Some(Ok(_))) => {
|
||||
return Poll::Ready(Err(ProtocolError::InvalidMessage.into()))
|
||||
},
|
||||
}
|
||||
Poll::Ready(Some(Err(err))) => return Poll::Ready(Err(From::from(err))),
|
||||
// Treat EOF error as [`NegotiationError::Failed`], not as
|
||||
// [`NegotiationError::ProtocolError`], allowing dropping or closing an I/O
|
||||
@ -126,7 +137,7 @@ where
|
||||
Poll::Ready(None) => return Poll::Ready(Err(NegotiationError::Failed)),
|
||||
Poll::Pending => {
|
||||
*this.state = State::RecvHeader { io };
|
||||
return Poll::Pending
|
||||
return Poll::Pending;
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -135,9 +146,9 @@ where
|
||||
match Pin::new(&mut io).poll_ready(cx) {
|
||||
Poll::Pending => {
|
||||
*this.state = State::SendHeader { io };
|
||||
return Poll::Pending
|
||||
},
|
||||
Poll::Ready(Ok(())) => {},
|
||||
return Poll::Pending;
|
||||
}
|
||||
Poll::Ready(Ok(())) => {}
|
||||
Poll::Ready(Err(err)) => return Poll::Ready(Err(From::from(err))),
|
||||
}
|
||||
|
||||
@ -175,28 +186,37 @@ where
|
||||
// the dialer also raises `NegotiationError::Failed` when finally
|
||||
// reading the `N/A` response.
|
||||
if let ProtocolError::InvalidMessage = &err {
|
||||
log::trace!("Listener: Negotiation failed with invalid \
|
||||
message after protocol rejection.");
|
||||
return Poll::Ready(Err(NegotiationError::Failed))
|
||||
log::trace!(
|
||||
"Listener: Negotiation failed with invalid \
|
||||
message after protocol rejection."
|
||||
);
|
||||
return Poll::Ready(Err(NegotiationError::Failed));
|
||||
}
|
||||
if let ProtocolError::IoError(e) = &err {
|
||||
if e.kind() == std::io::ErrorKind::UnexpectedEof {
|
||||
log::trace!("Listener: Negotiation failed with EOF \
|
||||
after protocol rejection.");
|
||||
return Poll::Ready(Err(NegotiationError::Failed))
|
||||
log::trace!(
|
||||
"Listener: Negotiation failed with EOF \
|
||||
after protocol rejection."
|
||||
);
|
||||
return Poll::Ready(Err(NegotiationError::Failed));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return Poll::Ready(Err(From::from(err)))
|
||||
return Poll::Ready(Err(From::from(err)));
|
||||
}
|
||||
};
|
||||
|
||||
match msg {
|
||||
Message::ListProtocols => {
|
||||
let supported = this.protocols.iter().map(|(_,p)| p).cloned().collect();
|
||||
let supported =
|
||||
this.protocols.iter().map(|(_, p)| p).cloned().collect();
|
||||
let message = Message::Protocols(supported);
|
||||
*this.state = State::SendMessage { io, message, protocol: None }
|
||||
*this.state = State::SendMessage {
|
||||
io,
|
||||
message,
|
||||
protocol: None,
|
||||
}
|
||||
}
|
||||
Message::Protocol(p) => {
|
||||
let protocol = this.protocols.iter().find_map(|(name, proto)| {
|
||||
@ -211,28 +231,42 @@ where
|
||||
log::debug!("Listener: confirming protocol: {}", p);
|
||||
Message::Protocol(p.clone())
|
||||
} else {
|
||||
log::debug!("Listener: rejecting protocol: {}",
|
||||
String::from_utf8_lossy(p.as_ref()));
|
||||
log::debug!(
|
||||
"Listener: rejecting protocol: {}",
|
||||
String::from_utf8_lossy(p.as_ref())
|
||||
);
|
||||
Message::NotAvailable
|
||||
};
|
||||
|
||||
*this.state = State::SendMessage { io, message, protocol };
|
||||
*this.state = State::SendMessage {
|
||||
io,
|
||||
message,
|
||||
protocol,
|
||||
};
|
||||
}
|
||||
_ => return Poll::Ready(Err(ProtocolError::InvalidMessage.into()))
|
||||
_ => return Poll::Ready(Err(ProtocolError::InvalidMessage.into())),
|
||||
}
|
||||
}
|
||||
|
||||
State::SendMessage { mut io, message, protocol } => {
|
||||
State::SendMessage {
|
||||
mut io,
|
||||
message,
|
||||
protocol,
|
||||
} => {
|
||||
match Pin::new(&mut io).poll_ready(cx) {
|
||||
Poll::Pending => {
|
||||
*this.state = State::SendMessage { io, message, protocol };
|
||||
return Poll::Pending
|
||||
},
|
||||
Poll::Ready(Ok(())) => {},
|
||||
*this.state = State::SendMessage {
|
||||
io,
|
||||
message,
|
||||
protocol,
|
||||
};
|
||||
return Poll::Pending;
|
||||
}
|
||||
Poll::Ready(Ok(())) => {}
|
||||
Poll::Ready(Err(err)) => return Poll::Ready(Err(From::from(err))),
|
||||
}
|
||||
|
||||
if let Message::NotAvailable = &message {
|
||||
if let Message::NotAvailable = &message {
|
||||
*this.last_sent_na = true;
|
||||
} else {
|
||||
*this.last_sent_na = false;
|
||||
@ -249,26 +283,28 @@ where
|
||||
match Pin::new(&mut io).poll_flush(cx) {
|
||||
Poll::Pending => {
|
||||
*this.state = State::Flush { io, protocol };
|
||||
return Poll::Pending
|
||||
},
|
||||
return Poll::Pending;
|
||||
}
|
||||
Poll::Ready(Ok(())) => {
|
||||
// If a protocol has been selected, finish negotiation.
|
||||
// Otherwise expect to receive another message.
|
||||
match protocol {
|
||||
Some(protocol) => {
|
||||
log::debug!("Listener: sent confirmed protocol: {}",
|
||||
String::from_utf8_lossy(protocol.as_ref()));
|
||||
log::debug!(
|
||||
"Listener: sent confirmed protocol: {}",
|
||||
String::from_utf8_lossy(protocol.as_ref())
|
||||
);
|
||||
let io = Negotiated::completed(io.into_inner());
|
||||
return Poll::Ready(Ok((protocol, io)))
|
||||
return Poll::Ready(Ok((protocol, io)));
|
||||
}
|
||||
None => *this.state = State::RecvMessage { io }
|
||||
None => *this.state = State::RecvMessage { io },
|
||||
}
|
||||
}
|
||||
Poll::Ready(Err(err)) => return Poll::Ready(Err(From::from(err))),
|
||||
}
|
||||
}
|
||||
|
||||
State::Done => panic!("State::poll called after completion")
|
||||
State::Done => panic!("State::poll called after completion"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -18,11 +18,20 @@
|
||||
// FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
|
||||
// DEALINGS IN THE SOFTWARE.
|
||||
|
||||
use crate::protocol::{Protocol, MessageReader, Message, ProtocolError, HeaderLine};
|
||||
use crate::protocol::{HeaderLine, Message, MessageReader, Protocol, ProtocolError};
|
||||
|
||||
use futures::{prelude::*, io::{IoSlice, IoSliceMut}, ready};
|
||||
use futures::{
|
||||
io::{IoSlice, IoSliceMut},
|
||||
prelude::*,
|
||||
ready,
|
||||
};
|
||||
use pin_project::pin_project;
|
||||
use std::{error::Error, fmt, io, mem, pin::Pin, task::{Context, Poll}};
|
||||
use std::{
|
||||
error::Error,
|
||||
fmt, io, mem,
|
||||
pin::Pin,
|
||||
task::{Context, Poll},
|
||||
};
|
||||
|
||||
/// An I/O stream that has settled on an (application-layer) protocol to use.
|
||||
///
|
||||
@ -39,7 +48,7 @@ use std::{error::Error, fmt, io, mem, pin::Pin, task::{Context, Poll}};
|
||||
#[derive(Debug)]
|
||||
pub struct Negotiated<TInner> {
|
||||
#[pin]
|
||||
state: State<TInner>
|
||||
state: State<TInner>,
|
||||
}
|
||||
|
||||
/// A `Future` that waits on the completion of protocol negotiation.
|
||||
@ -57,12 +66,15 @@ where
|
||||
type Output = Result<Negotiated<TInner>, NegotiationError>;
|
||||
|
||||
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
|
||||
let mut io = self.inner.take().expect("NegotiatedFuture called after completion.");
|
||||
let mut io = self
|
||||
.inner
|
||||
.take()
|
||||
.expect("NegotiatedFuture called after completion.");
|
||||
match Negotiated::poll(Pin::new(&mut io), cx) {
|
||||
Poll::Pending => {
|
||||
self.inner = Some(io);
|
||||
Poll::Pending
|
||||
},
|
||||
}
|
||||
Poll::Ready(Ok(())) => Poll::Ready(Ok(io)),
|
||||
Poll::Ready(Err(err)) => {
|
||||
self.inner = Some(io);
|
||||
@ -75,7 +87,9 @@ where
|
||||
impl<TInner> Negotiated<TInner> {
|
||||
/// Creates a `Negotiated` in state [`State::Completed`].
|
||||
pub(crate) fn completed(io: TInner) -> Self {
|
||||
Negotiated { state: State::Completed { io } }
|
||||
Negotiated {
|
||||
state: State::Completed { io },
|
||||
}
|
||||
}
|
||||
|
||||
/// Creates a `Negotiated` in state [`State::Expecting`] that is still
|
||||
@ -83,25 +97,31 @@ impl<TInner> Negotiated<TInner> {
|
||||
pub(crate) fn expecting(
|
||||
io: MessageReader<TInner>,
|
||||
protocol: Protocol,
|
||||
header: Option<HeaderLine>
|
||||
header: Option<HeaderLine>,
|
||||
) -> Self {
|
||||
Negotiated { state: State::Expecting { io, protocol, header } }
|
||||
Negotiated {
|
||||
state: State::Expecting {
|
||||
io,
|
||||
protocol,
|
||||
header,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
/// Polls the `Negotiated` for completion.
|
||||
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), NegotiationError>>
|
||||
where
|
||||
TInner: AsyncRead + AsyncWrite + Unpin
|
||||
TInner: AsyncRead + AsyncWrite + Unpin,
|
||||
{
|
||||
// Flush any pending negotiation data.
|
||||
match self.as_mut().poll_flush(cx) {
|
||||
Poll::Ready(Ok(())) => {},
|
||||
Poll::Ready(Ok(())) => {}
|
||||
Poll::Pending => return Poll::Pending,
|
||||
Poll::Ready(Err(e)) => {
|
||||
// If the remote closed the stream, it is important to still
|
||||
// continue reading the data that was sent, if any.
|
||||
if e.kind() != io::ErrorKind::WriteZero {
|
||||
return Poll::Ready(Err(e.into()))
|
||||
return Poll::Ready(Err(e.into()));
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -109,36 +129,52 @@ impl<TInner> Negotiated<TInner> {
|
||||
let mut this = self.project();
|
||||
|
||||
if let StateProj::Completed { .. } = this.state.as_mut().project() {
|
||||
return Poll::Ready(Ok(()));
|
||||
return Poll::Ready(Ok(()));
|
||||
}
|
||||
|
||||
// Read outstanding protocol negotiation messages.
|
||||
loop {
|
||||
match mem::replace(&mut *this.state, State::Invalid) {
|
||||
State::Expecting { mut io, header, protocol } => {
|
||||
State::Expecting {
|
||||
mut io,
|
||||
header,
|
||||
protocol,
|
||||
} => {
|
||||
let msg = match Pin::new(&mut io).poll_next(cx)? {
|
||||
Poll::Ready(Some(msg)) => msg,
|
||||
Poll::Pending => {
|
||||
*this.state = State::Expecting { io, header, protocol };
|
||||
return Poll::Pending
|
||||
},
|
||||
*this.state = State::Expecting {
|
||||
io,
|
||||
header,
|
||||
protocol,
|
||||
};
|
||||
return Poll::Pending;
|
||||
}
|
||||
Poll::Ready(None) => {
|
||||
return Poll::Ready(Err(ProtocolError::IoError(
|
||||
io::ErrorKind::UnexpectedEof.into()).into()));
|
||||
io::ErrorKind::UnexpectedEof.into(),
|
||||
)
|
||||
.into()));
|
||||
}
|
||||
};
|
||||
|
||||
if let Message::Header(h) = &msg {
|
||||
if Some(h) == header.as_ref() {
|
||||
*this.state = State::Expecting { io, protocol, header: None };
|
||||
continue
|
||||
*this.state = State::Expecting {
|
||||
io,
|
||||
protocol,
|
||||
header: None,
|
||||
};
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
if let Message::Protocol(p) = &msg {
|
||||
if p.as_ref() == protocol.as_ref() {
|
||||
log::debug!("Negotiated: Received confirmation for protocol: {}", p);
|
||||
*this.state = State::Completed { io: io.into_inner() };
|
||||
*this.state = State::Completed {
|
||||
io: io.into_inner(),
|
||||
};
|
||||
return Poll::Ready(Ok(()));
|
||||
}
|
||||
}
|
||||
@ -146,7 +182,7 @@ impl<TInner> Negotiated<TInner> {
|
||||
return Poll::Ready(Err(NegotiationError::Failed));
|
||||
}
|
||||
|
||||
_ => panic!("Negotiated: Invalid state")
|
||||
_ => panic!("Negotiated: Invalid state"),
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -178,7 +214,10 @@ enum State<R> {
|
||||
|
||||
/// In this state, a protocol has been agreed upon and I/O
|
||||
/// on the underlying stream can commence.
|
||||
Completed { #[pin] io: R },
|
||||
Completed {
|
||||
#[pin]
|
||||
io: R,
|
||||
},
|
||||
|
||||
/// Temporary state while moving the `io` resource from
|
||||
/// `Expecting` to `Completed`.
|
||||
@ -187,11 +226,13 @@ enum State<R> {
|
||||
|
||||
impl<TInner> AsyncRead for Negotiated<TInner>
|
||||
where
|
||||
TInner: AsyncRead + AsyncWrite + Unpin
|
||||
TInner: AsyncRead + AsyncWrite + Unpin,
|
||||
{
|
||||
fn poll_read(mut self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut [u8])
|
||||
-> Poll<Result<usize, io::Error>>
|
||||
{
|
||||
fn poll_read(
|
||||
mut self: Pin<&mut Self>,
|
||||
cx: &mut Context<'_>,
|
||||
buf: &mut [u8],
|
||||
) -> Poll<Result<usize, io::Error>> {
|
||||
loop {
|
||||
if let StateProj::Completed { io } = self.as_mut().project().state.project() {
|
||||
// If protocol negotiation is complete, commence with reading.
|
||||
@ -201,7 +242,7 @@ where
|
||||
// Poll the `Negotiated`, driving protocol negotiation to completion,
|
||||
// including flushing of any remaining data.
|
||||
match self.as_mut().poll(cx) {
|
||||
Poll::Ready(Ok(())) => {},
|
||||
Poll::Ready(Ok(())) => {}
|
||||
Poll::Pending => return Poll::Pending,
|
||||
Poll::Ready(Err(err)) => return Poll::Ready(Err(From::from(err))),
|
||||
}
|
||||
@ -217,19 +258,21 @@ where
|
||||
}
|
||||
}*/
|
||||
|
||||
fn poll_read_vectored(mut self: Pin<&mut Self>, cx: &mut Context<'_>, bufs: &mut [IoSliceMut<'_>])
|
||||
-> Poll<Result<usize, io::Error>>
|
||||
{
|
||||
fn poll_read_vectored(
|
||||
mut self: Pin<&mut Self>,
|
||||
cx: &mut Context<'_>,
|
||||
bufs: &mut [IoSliceMut<'_>],
|
||||
) -> Poll<Result<usize, io::Error>> {
|
||||
loop {
|
||||
if let StateProj::Completed { io } = self.as_mut().project().state.project() {
|
||||
// If protocol negotiation is complete, commence with reading.
|
||||
return io.poll_read_vectored(cx, bufs)
|
||||
return io.poll_read_vectored(cx, bufs);
|
||||
}
|
||||
|
||||
// Poll the `Negotiated`, driving protocol negotiation to completion,
|
||||
// including flushing of any remaining data.
|
||||
match self.as_mut().poll(cx) {
|
||||
Poll::Ready(Ok(())) => {},
|
||||
Poll::Ready(Ok(())) => {}
|
||||
Poll::Pending => return Poll::Pending,
|
||||
Poll::Ready(Err(err)) => return Poll::Ready(Err(From::from(err))),
|
||||
}
|
||||
@ -239,9 +282,13 @@ where
|
||||
|
||||
impl<TInner> AsyncWrite for Negotiated<TInner>
|
||||
where
|
||||
TInner: AsyncWrite + AsyncRead + Unpin
|
||||
TInner: AsyncWrite + AsyncRead + Unpin,
|
||||
{
|
||||
fn poll_write(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll<Result<usize, io::Error>> {
|
||||
fn poll_write(
|
||||
self: Pin<&mut Self>,
|
||||
cx: &mut Context<'_>,
|
||||
buf: &[u8],
|
||||
) -> Poll<Result<usize, io::Error>> {
|
||||
match self.project().state.project() {
|
||||
StateProj::Completed { io } => io.poll_write(cx, buf),
|
||||
StateProj::Expecting { io, .. } => io.poll_write(cx, buf),
|
||||
@ -261,7 +308,10 @@ where
|
||||
// Ensure all data has been flushed and expected negotiation messages
|
||||
// have been received.
|
||||
ready!(self.as_mut().poll(cx).map_err(Into::<io::Error>::into)?);
|
||||
ready!(self.as_mut().poll_flush(cx).map_err(Into::<io::Error>::into)?);
|
||||
ready!(self
|
||||
.as_mut()
|
||||
.poll_flush(cx)
|
||||
.map_err(Into::<io::Error>::into)?);
|
||||
|
||||
// Continue with the shutdown of the underlying I/O stream.
|
||||
match self.project().state.project() {
|
||||
@ -271,9 +321,11 @@ where
|
||||
}
|
||||
}
|
||||
|
||||
fn poll_write_vectored(self: Pin<&mut Self>, cx: &mut Context<'_>, bufs: &[IoSlice<'_>])
|
||||
-> Poll<Result<usize, io::Error>>
|
||||
{
|
||||
fn poll_write_vectored(
|
||||
self: Pin<&mut Self>,
|
||||
cx: &mut Context<'_>,
|
||||
bufs: &[IoSlice<'_>],
|
||||
) -> Poll<Result<usize, io::Error>> {
|
||||
match self.project().state.project() {
|
||||
StateProj::Completed { io } => io.poll_write_vectored(cx, bufs),
|
||||
StateProj::Expecting { io, .. } => io.poll_write_vectored(cx, bufs),
|
||||
@ -307,7 +359,7 @@ impl From<io::Error> for NegotiationError {
|
||||
impl From<NegotiationError> for io::Error {
|
||||
fn from(err: NegotiationError) -> io::Error {
|
||||
if let NegotiationError::ProtocolError(e) = err {
|
||||
return e.into()
|
||||
return e.into();
|
||||
}
|
||||
io::Error::new(io::ErrorKind::Other, err)
|
||||
}
|
||||
@ -325,10 +377,10 @@ impl Error for NegotiationError {
|
||||
impl fmt::Display for NegotiationError {
|
||||
fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> {
|
||||
match self {
|
||||
NegotiationError::ProtocolError(p) =>
|
||||
fmt.write_fmt(format_args!("Protocol error: {}", p)),
|
||||
NegotiationError::Failed =>
|
||||
fmt.write_str("Protocol negotiation failed.")
|
||||
NegotiationError::ProtocolError(p) => {
|
||||
fmt.write_fmt(format_args!("Protocol error: {}", p))
|
||||
}
|
||||
NegotiationError::Failed => fmt.write_str("Protocol negotiation failed."),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -25,12 +25,18 @@
|
||||
//! `Stream` and `Sink` implementations of `MessageIO` and
|
||||
//! `MessageReader`.
|
||||
|
||||
use crate::Version;
|
||||
use crate::length_delimited::{LengthDelimited, LengthDelimitedReader};
|
||||
use crate::Version;
|
||||
|
||||
use bytes::{Bytes, BytesMut, BufMut};
|
||||
use futures::{prelude::*, io::IoSlice, ready};
|
||||
use std::{convert::TryFrom, io, fmt, error::Error, pin::Pin, task::{Context, Poll}};
|
||||
use bytes::{BufMut, Bytes, BytesMut};
|
||||
use futures::{io::IoSlice, prelude::*, ready};
|
||||
use std::{
|
||||
convert::TryFrom,
|
||||
error::Error,
|
||||
fmt, io,
|
||||
pin::Pin,
|
||||
task::{Context, Poll},
|
||||
};
|
||||
use unsigned_varint as uvi;
|
||||
|
||||
/// The maximum number of supported protocols that can be processed.
|
||||
@ -75,7 +81,7 @@ impl TryFrom<Bytes> for Protocol {
|
||||
|
||||
fn try_from(value: Bytes) -> Result<Self, Self::Error> {
|
||||
if !value.as_ref().starts_with(b"/") {
|
||||
return Err(ProtocolError::InvalidProtocol)
|
||||
return Err(ProtocolError::InvalidProtocol);
|
||||
}
|
||||
Ok(Protocol(value))
|
||||
}
|
||||
@ -160,7 +166,7 @@ impl Message {
|
||||
/// Decodes a `Message` from its byte representation.
|
||||
pub fn decode(mut msg: Bytes) -> Result<Message, ProtocolError> {
|
||||
if msg == MSG_MULTISTREAM_1_0 {
|
||||
return Ok(Message::Header(HeaderLine::V1))
|
||||
return Ok(Message::Header(HeaderLine::V1));
|
||||
}
|
||||
|
||||
if msg == MSG_PROTOCOL_NA {
|
||||
@ -168,13 +174,14 @@ impl Message {
|
||||
}
|
||||
|
||||
if msg == MSG_LS {
|
||||
return Ok(Message::ListProtocols)
|
||||
return Ok(Message::ListProtocols);
|
||||
}
|
||||
|
||||
// If it starts with a `/`, ends with a line feed without any
|
||||
// other line feeds in-between, it must be a protocol name.
|
||||
if msg.get(0) == Some(&b'/') && msg.last() == Some(&b'\n') &&
|
||||
!msg[.. msg.len() - 1].contains(&b'\n')
|
||||
if msg.get(0) == Some(&b'/')
|
||||
&& msg.last() == Some(&b'\n')
|
||||
&& !msg[..msg.len() - 1].contains(&b'\n')
|
||||
{
|
||||
let p = Protocol::try_from(msg.split_to(msg.len() - 1))?;
|
||||
return Ok(Message::Protocol(p));
|
||||
@ -187,24 +194,24 @@ impl Message {
|
||||
loop {
|
||||
// A well-formed message must be terminated with a newline.
|
||||
if remaining == [b'\n'] {
|
||||
break
|
||||
break;
|
||||
} else if protocols.len() == MAX_PROTOCOLS {
|
||||
return Err(ProtocolError::TooManyProtocols)
|
||||
return Err(ProtocolError::TooManyProtocols);
|
||||
}
|
||||
|
||||
// Decode the length of the next protocol name and check that
|
||||
// it ends with a line feed.
|
||||
let (len, tail) = uvi::decode::usize(remaining)?;
|
||||
if len == 0 || len > tail.len() || tail[len - 1] != b'\n' {
|
||||
return Err(ProtocolError::InvalidMessage)
|
||||
return Err(ProtocolError::InvalidMessage);
|
||||
}
|
||||
|
||||
// Parse the protocol name.
|
||||
let p = Protocol::try_from(Bytes::copy_from_slice(&tail[.. len - 1]))?;
|
||||
let p = Protocol::try_from(Bytes::copy_from_slice(&tail[..len - 1]))?;
|
||||
protocols.push(p);
|
||||
|
||||
// Skip ahead to the next protocol.
|
||||
remaining = &tail[len ..];
|
||||
remaining = &tail[len..];
|
||||
}
|
||||
|
||||
Ok(Message::Protocols(protocols))
|
||||
@ -222,9 +229,11 @@ impl<R> MessageIO<R> {
|
||||
/// Constructs a new `MessageIO` resource wrapping the given I/O stream.
|
||||
pub fn new(inner: R) -> MessageIO<R>
|
||||
where
|
||||
R: AsyncRead + AsyncWrite
|
||||
R: AsyncRead + AsyncWrite,
|
||||
{
|
||||
Self { inner: LengthDelimited::new(inner) }
|
||||
Self {
|
||||
inner: LengthDelimited::new(inner),
|
||||
}
|
||||
}
|
||||
|
||||
/// Converts the [`MessageIO`] into a [`MessageReader`], dropping the
|
||||
@ -235,7 +244,9 @@ impl<R> MessageIO<R> {
|
||||
/// received but no more messages are written, allowing the writing of
|
||||
/// follow-up protocol data to commence.
|
||||
pub fn into_reader(self) -> MessageReader<R> {
|
||||
MessageReader { inner: self.inner.into_reader() }
|
||||
MessageReader {
|
||||
inner: self.inner.into_reader(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Drops the [`MessageIO`] resource, yielding the underlying I/O stream.
|
||||
@ -265,7 +276,10 @@ where
|
||||
fn start_send(self: Pin<&mut Self>, item: Message) -> Result<(), Self::Error> {
|
||||
let mut buf = BytesMut::new();
|
||||
item.encode(&mut buf)?;
|
||||
self.project().inner.start_send(buf.freeze()).map_err(From::from)
|
||||
self.project()
|
||||
.inner
|
||||
.start_send(buf.freeze())
|
||||
.map_err(From::from)
|
||||
}
|
||||
|
||||
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
|
||||
@ -279,7 +293,7 @@ where
|
||||
|
||||
impl<R> Stream for MessageIO<R>
|
||||
where
|
||||
R: AsyncRead
|
||||
R: AsyncRead,
|
||||
{
|
||||
type Item = Result<Message, ProtocolError>;
|
||||
|
||||
@ -299,7 +313,7 @@ where
|
||||
#[derive(Debug)]
|
||||
pub struct MessageReader<R> {
|
||||
#[pin]
|
||||
inner: LengthDelimitedReader<R>
|
||||
inner: LengthDelimitedReader<R>,
|
||||
}
|
||||
|
||||
impl<R> MessageReader<R> {
|
||||
@ -321,7 +335,7 @@ impl<R> MessageReader<R> {
|
||||
|
||||
impl<R> Stream for MessageReader<R>
|
||||
where
|
||||
R: AsyncRead
|
||||
R: AsyncRead,
|
||||
{
|
||||
type Item = Result<Message, ProtocolError>;
|
||||
|
||||
@ -332,9 +346,13 @@ where
|
||||
|
||||
impl<TInner> AsyncWrite for MessageReader<TInner>
|
||||
where
|
||||
TInner: AsyncWrite
|
||||
TInner: AsyncWrite,
|
||||
{
|
||||
fn poll_write(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll<Result<usize, io::Error>> {
|
||||
fn poll_write(
|
||||
self: Pin<&mut Self>,
|
||||
cx: &mut Context<'_>,
|
||||
buf: &[u8],
|
||||
) -> Poll<Result<usize, io::Error>> {
|
||||
self.project().inner.poll_write(cx, buf)
|
||||
}
|
||||
|
||||
@ -346,12 +364,19 @@ where
|
||||
self.project().inner.poll_close(cx)
|
||||
}
|
||||
|
||||
fn poll_write_vectored(self: Pin<&mut Self>, cx: &mut Context<'_>, bufs: &[IoSlice<'_>]) -> Poll<Result<usize, io::Error>> {
|
||||
fn poll_write_vectored(
|
||||
self: Pin<&mut Self>,
|
||||
cx: &mut Context<'_>,
|
||||
bufs: &[IoSlice<'_>],
|
||||
) -> Poll<Result<usize, io::Error>> {
|
||||
self.project().inner.poll_write_vectored(cx, bufs)
|
||||
}
|
||||
}
|
||||
|
||||
fn poll_stream<S>(stream: Pin<&mut S>, cx: &mut Context<'_>) -> Poll<Option<Result<Message, ProtocolError>>>
|
||||
fn poll_stream<S>(
|
||||
stream: Pin<&mut S>,
|
||||
cx: &mut Context<'_>,
|
||||
) -> Poll<Option<Result<Message, ProtocolError>>>
|
||||
where
|
||||
S: Stream<Item = Result<Bytes, io::Error>>,
|
||||
{
|
||||
@ -361,7 +386,7 @@ where
|
||||
Err(err) => return Poll::Ready(Some(Err(err))),
|
||||
}
|
||||
} else {
|
||||
return Poll::Ready(None)
|
||||
return Poll::Ready(None);
|
||||
};
|
||||
|
||||
log::trace!("Received message: {:?}", msg);
|
||||
@ -394,7 +419,7 @@ impl From<io::Error> for ProtocolError {
|
||||
impl From<ProtocolError> for io::Error {
|
||||
fn from(err: ProtocolError) -> Self {
|
||||
if let ProtocolError::IoError(e) = err {
|
||||
return e
|
||||
return e;
|
||||
}
|
||||
io::ErrorKind::InvalidData.into()
|
||||
}
|
||||
@ -418,14 +443,10 @@ impl Error for ProtocolError {
|
||||
impl fmt::Display for ProtocolError {
|
||||
fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> {
|
||||
match self {
|
||||
ProtocolError::IoError(e) =>
|
||||
write!(fmt, "I/O error: {}", e),
|
||||
ProtocolError::InvalidMessage =>
|
||||
write!(fmt, "Received an invalid message."),
|
||||
ProtocolError::InvalidProtocol =>
|
||||
write!(fmt, "A protocol (name) is invalid."),
|
||||
ProtocolError::TooManyProtocols =>
|
||||
write!(fmt, "Too many protocols received.")
|
||||
ProtocolError::IoError(e) => write!(fmt, "I/O error: {}", e),
|
||||
ProtocolError::InvalidMessage => write!(fmt, "Received an invalid message."),
|
||||
ProtocolError::InvalidProtocol => write!(fmt, "A protocol (name) is invalid."),
|
||||
ProtocolError::TooManyProtocols => write!(fmt, "Too many protocols received."),
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -434,8 +455,8 @@ impl fmt::Display for ProtocolError {
|
||||
mod tests {
|
||||
use super::*;
|
||||
use quickcheck::*;
|
||||
use rand::Rng;
|
||||
use rand::distributions::Alphanumeric;
|
||||
use rand::Rng;
|
||||
use std::iter;
|
||||
|
||||
impl Arbitrary for Protocol {
|
||||
@ -457,7 +478,7 @@ mod tests {
|
||||
2 => Message::ListProtocols,
|
||||
3 => Message::Protocol(Protocol::arbitrary(g)),
|
||||
4 => Message::Protocols(Vec::arbitrary(g)),
|
||||
_ => panic!()
|
||||
_ => panic!(),
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -466,10 +487,11 @@ mod tests {
|
||||
fn encode_decode_message() {
|
||||
fn prop(msg: Message) {
|
||||
let mut buf = BytesMut::new();
|
||||
msg.encode(&mut buf).expect(&format!("Encoding message failed: {:?}", msg));
|
||||
msg.encode(&mut buf)
|
||||
.expect(&format!("Encoding message failed: {:?}", msg));
|
||||
match Message::decode(buf.freeze()) {
|
||||
Ok(m) => assert_eq!(m, msg),
|
||||
Err(e) => panic!("Decoding failed: {:?}", e)
|
||||
Err(e) => panic!("Decoding failed: {:?}", e),
|
||||
}
|
||||
}
|
||||
quickcheck(prop as fn(_))
|
||||
|
@ -22,9 +22,9 @@
|
||||
|
||||
#![cfg(test)]
|
||||
|
||||
use crate::{Version, NegotiationError};
|
||||
use crate::dialer_select::{dialer_select_proto_parallel, dialer_select_proto_serial};
|
||||
use crate::{dialer_select_proto, listener_select_proto};
|
||||
use crate::{NegotiationError, Version};
|
||||
|
||||
use async_std::net::{TcpListener, TcpStream};
|
||||
use futures::prelude::*;
|
||||
@ -54,7 +54,8 @@ fn select_proto_basic() {
|
||||
let connec = TcpStream::connect(&listener_addr).await.unwrap();
|
||||
let protos = vec![b"/proto3", b"/proto2"];
|
||||
let (proto, mut io) = dialer_select_proto(connec, protos.into_iter(), version)
|
||||
.await.unwrap();
|
||||
.await
|
||||
.unwrap();
|
||||
assert_eq!(proto, b"/proto2");
|
||||
|
||||
io.write_all(b"ping").await.unwrap();
|
||||
@ -79,12 +80,14 @@ fn select_proto_basic() {
|
||||
fn negotiation_failed() {
|
||||
let _ = env_logger::try_init();
|
||||
|
||||
async fn run(Test {
|
||||
version,
|
||||
listen_protos,
|
||||
dial_protos,
|
||||
dial_payload
|
||||
}: Test) {
|
||||
async fn run(
|
||||
Test {
|
||||
version,
|
||||
listen_protos,
|
||||
dial_protos,
|
||||
dial_payload,
|
||||
}: Test,
|
||||
) {
|
||||
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
|
||||
let listener_addr = listener.local_addr().unwrap();
|
||||
|
||||
@ -93,10 +96,12 @@ fn negotiation_failed() {
|
||||
let io = match listener_select_proto(connec, listen_protos).await {
|
||||
Ok((_, io)) => io,
|
||||
Err(NegotiationError::Failed) => return,
|
||||
Err(NegotiationError::ProtocolError(e)) => panic!("Unexpected protocol error {}", e),
|
||||
Err(NegotiationError::ProtocolError(e)) => {
|
||||
panic!("Unexpected protocol error {}", e)
|
||||
}
|
||||
};
|
||||
match io.complete().await {
|
||||
Err(NegotiationError::Failed) => {},
|
||||
Err(NegotiationError::Failed) => {}
|
||||
_ => panic!(),
|
||||
}
|
||||
});
|
||||
@ -106,14 +111,14 @@ fn negotiation_failed() {
|
||||
let mut io = match dialer_select_proto(connec, dial_protos.into_iter(), version).await {
|
||||
Err(NegotiationError::Failed) => return,
|
||||
Ok((_, io)) => io,
|
||||
Err(_) => panic!()
|
||||
Err(_) => panic!(),
|
||||
};
|
||||
// The dialer may write a payload that is even sent before it
|
||||
// got confirmation of the last proposed protocol, when `V1Lazy`
|
||||
// is used.
|
||||
io.write_all(&dial_payload).await.unwrap();
|
||||
match io.complete().await {
|
||||
Err(NegotiationError::Failed) => {},
|
||||
Err(NegotiationError::Failed) => {}
|
||||
_ => panic!(),
|
||||
}
|
||||
});
|
||||
@ -135,10 +140,10 @@ fn negotiation_failed() {
|
||||
//
|
||||
// The choices here cover the main distinction between a single
|
||||
// and multiple protocols.
|
||||
let protos = vec!{
|
||||
let protos = vec![
|
||||
(vec!["/proto1"], vec!["/proto2"]),
|
||||
(vec!["/proto1", "/proto2"], vec!["/proto3", "/proto4"]),
|
||||
};
|
||||
];
|
||||
|
||||
// The payloads that the dialer sends after "successful" negotiation,
|
||||
// which may be sent even before the dialer got protocol confirmation
|
||||
@ -147,7 +152,7 @@ fn negotiation_failed() {
|
||||
// The choices here cover the specific situations that can arise with
|
||||
// `V1Lazy` and which must nevertheless behave identically to `V1` w.r.t.
|
||||
// the outcome of the negotiation.
|
||||
let payloads = vec!{
|
||||
let payloads = vec![
|
||||
// No payload, in which case all versions should behave identically
|
||||
// in any case, i.e. the baseline test.
|
||||
vec![],
|
||||
@ -155,13 +160,13 @@ fn negotiation_failed() {
|
||||
// `1` as a message length and encounters an invalid message (the
|
||||
// second `1`). The listener is nevertheless expected to fail
|
||||
// negotiation normally, just like with `V1`.
|
||||
vec![1,1],
|
||||
vec![1, 1],
|
||||
// With this payload and `V1Lazy`, the listener interprets the first
|
||||
// `42` as a message length and encounters unexpected EOF trying to
|
||||
// read a message of that length. The listener is nevertheless expected
|
||||
// to fail negotiation normally, just like with `V1`
|
||||
vec![42,1],
|
||||
};
|
||||
vec![42, 1],
|
||||
];
|
||||
|
||||
for (listen_protos, dial_protos) in protos {
|
||||
for dial_payload in payloads.clone() {
|
||||
@ -195,7 +200,8 @@ fn select_proto_parallel() {
|
||||
let connec = TcpStream::connect(&listener_addr).await.unwrap();
|
||||
let protos = vec![b"/proto3", b"/proto2"];
|
||||
let (proto, io) = dialer_select_proto_parallel(connec, protos.into_iter(), version)
|
||||
.await.unwrap();
|
||||
.await
|
||||
.unwrap();
|
||||
assert_eq!(proto, b"/proto2");
|
||||
io.complete().await.unwrap();
|
||||
});
|
||||
@ -226,7 +232,8 @@ fn select_proto_serial() {
|
||||
let connec = TcpStream::connect(&listener_addr).await.unwrap();
|
||||
let protos = vec![b"/proto3", b"/proto2"];
|
||||
let (proto, io) = dialer_select_proto_serial(connec, protos.into_iter(), version)
|
||||
.await.unwrap();
|
||||
.await
|
||||
.unwrap();
|
||||
assert_eq!(proto, b"/proto2");
|
||||
io.complete().await.unwrap();
|
||||
});
|
||||
|
Reference in New Issue
Block a user