Remove tokio-codec dependency from multistream-select. (#1203)

* Remove tokio-codec dependency from multistream-select.

In preparation for the eventual switch from tokio to std futures.

Includes some initial refactoring in preparation for further work
in the context of https://github.com/libp2p/rust-libp2p/issues/659.

* Reduce default buffer sizes.

* Allow more than one frame to be buffered for sending.

* Doc tweaks.

* Remove superfluous (duplicated) Message types.
This commit is contained in:
Roman Borschel 2019-07-29 17:06:23 +02:00 committed by GitHub
parent bcc7c4d349
commit 2fd941122a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
11 changed files with 383 additions and 401 deletions

View File

@ -14,9 +14,8 @@ bytes = "0.4"
futures = { version = "0.1" } futures = { version = "0.1" }
log = "0.4" log = "0.4"
smallvec = "0.6" smallvec = "0.6"
tokio-codec = "0.1"
tokio-io = "0.1" tokio-io = "0.1"
unsigned-varint = { version = "0.2.1", features = ["codec"] } unsigned-varint = { version = "0.2.2" }
[dev-dependencies] [dev-dependencies]
tokio = "0.1" tokio = "0.1"

View File

@ -22,19 +22,14 @@
//! `multistream-select` for the dialer. //! `multistream-select` for the dialer.
use futures::{future::Either, prelude::*, stream::StreamFuture}; use futures::{future::Either, prelude::*, stream::StreamFuture};
use crate::protocol::{ use crate::protocol::{Dialer, DialerFuture, Request, Response};
Dialer,
DialerFuture,
DialerToListenerMessage,
ListenerToDialerMessage
};
use log::trace; use log::trace;
use std::mem; use std::mem;
use tokio_io::{AsyncRead, AsyncWrite}; use tokio_io::{AsyncRead, AsyncWrite};
use crate::{Negotiated, ProtocolChoiceError}; use crate::{Negotiated, ProtocolChoiceError};
/// Future, returned by `dialer_select_proto`, which selects a protocol and dialer /// Future, returned by `dialer_select_proto`, which selects a protocol and dialer
/// either sequentially of by considering all protocols in parallel. /// either sequentially or by considering all protocols in parallel.
pub type DialerSelectFuture<R, I> = Either<DialerSelectSeq<R, I>, DialerSelectPar<R, I>>; pub type DialerSelectFuture<R, I> = Either<DialerSelectSeq<R, I>, DialerSelectPar<R, I>>;
/// Helps selecting a protocol amongst the ones supported. /// Helps selecting a protocol amongst the ones supported.
@ -75,7 +70,10 @@ where
{ {
let protocols = protocols.into_iter(); let protocols = protocols.into_iter();
DialerSelectSeq { DialerSelectSeq {
inner: DialerSelectSeqState::AwaitDialer { dialer_fut: Dialer::dial(inner), protocols } inner: DialerSelectSeqState::AwaitDialer {
dialer_fut: Dialer::dial(inner),
protocols
}
} }
} }
@ -148,9 +146,7 @@ where
} }
DialerSelectSeqState::NextProtocol { mut dialer, protocols, proto_name } => { DialerSelectSeqState::NextProtocol { mut dialer, protocols, proto_name } => {
trace!("sending {:?}", proto_name.as_ref()); trace!("sending {:?}", proto_name.as_ref());
let req = DialerToListenerMessage::ProtocolRequest { let req = Request::Protocol { name: proto_name.clone() };
name: proto_name.clone()
};
match dialer.start_send(req)? { match dialer.start_send(req)? {
AsyncSink::Ready => { AsyncSink::Ready => {
self.inner = DialerSelectSeqState::FlushProtocol { self.inner = DialerSelectSeqState::FlushProtocol {
@ -204,12 +200,12 @@ where
}; };
trace!("received {:?}", m); trace!("received {:?}", m);
match m.ok_or(ProtocolChoiceError::UnexpectedMessage)? { match m.ok_or(ProtocolChoiceError::UnexpectedMessage)? {
ListenerToDialerMessage::ProtocolAck { ref name } Response::Protocol { ref name }
if name.as_ref() == proto_name.as_ref() => if name.as_ref() == proto_name.as_ref() =>
{ {
return Ok(Async::Ready((proto_name, Negotiated(r.into_inner())))) return Ok(Async::Ready((proto_name, Negotiated(r.into_inner()))))
} }
ListenerToDialerMessage::NotAvailable => { Response::ProtocolNotAvailable => {
let proto_name = protocols.next() let proto_name = protocols.next()
.ok_or(ProtocolChoiceError::NoProtocolFound)?; .ok_or(ProtocolChoiceError::NoProtocolFound)?;
self.inner = DialerSelectSeqState::NextProtocol { self.inner = DialerSelectSeqState::NextProtocol {
@ -244,9 +240,8 @@ where
} }
} }
/// Future, returned by `dialer_select_proto_parallel`, which selects a protocol and dialer in /// Future, returned by `dialer_select_proto_parallel`, which selects a protocol and dialer in
/// parellel, by first requesting the liste of protocols supported by the remote endpoint and /// parallel, by first requesting the list of protocols supported by the remote endpoint and
/// then selecting the most appropriate one by applying a match predicate to the result. /// then selecting the most appropriate one by applying a match predicate to the result.
pub struct DialerSelectPar<R, I> pub struct DialerSelectPar<R, I>
where where
@ -319,7 +314,7 @@ where
} }
DialerSelectParState::ProtocolList { mut dialer, protocols } => { DialerSelectParState::ProtocolList { mut dialer, protocols } => {
trace!("requesting protocols list"); trace!("requesting protocols list");
match dialer.start_send(DialerToListenerMessage::ProtocolsListRequest)? { match dialer.start_send(Request::ListProtocols)? {
AsyncSink::Ready => { AsyncSink::Ready => {
self.inner = DialerSelectParState::FlushListRequest { self.inner = DialerSelectParState::FlushListRequest {
dialer, dialer,
@ -359,15 +354,15 @@ where
Err((e, _)) => return Err(ProtocolChoiceError::from(e)) Err((e, _)) => return Err(ProtocolChoiceError::from(e))
}; };
trace!("protocols list response: {:?}", resp); trace!("protocols list response: {:?}", resp);
let list = let supported =
if let Some(ListenerToDialerMessage::ProtocolsListResponse { list }) = resp { if let Some(Response::SupportedProtocols { protocols }) = resp {
list protocols
} else { } else {
return Err(ProtocolChoiceError::UnexpectedMessage) return Err(ProtocolChoiceError::UnexpectedMessage)
}; };
let mut found = None; let mut found = None;
for local_name in protocols { for local_name in protocols {
for remote_name in &list { for remote_name in &supported {
if remote_name.as_ref() == local_name.as_ref() { if remote_name.as_ref() == local_name.as_ref() {
found = Some(local_name); found = Some(local_name);
break; break;
@ -381,10 +376,8 @@ where
self.inner = DialerSelectParState::Protocol { dialer, proto_name } self.inner = DialerSelectParState::Protocol { dialer, proto_name }
} }
DialerSelectParState::Protocol { mut dialer, proto_name } => { DialerSelectParState::Protocol { mut dialer, proto_name } => {
trace!("requesting protocol: {:?}", proto_name.as_ref()); trace!("Requesting protocol: {:?}", proto_name.as_ref());
let req = DialerToListenerMessage::ProtocolRequest { let req = Request::Protocol { name: proto_name.clone() };
name: proto_name.clone()
};
match dialer.start_send(req)? { match dialer.start_send(req)? {
AsyncSink::Ready => { AsyncSink::Ready => {
self.inner = DialerSelectParState::FlushProtocol { dialer, proto_name } self.inner = DialerSelectParState::FlushProtocol { dialer, proto_name }
@ -420,7 +413,7 @@ where
}; };
trace!("received {:?}", resp); trace!("received {:?}", resp);
match resp { match resp {
Some(ListenerToDialerMessage::ProtocolAck { ref name }) Some(Response::Protocol { ref name })
if name.as_ref() == proto_name.as_ref() => if name.as_ref() == proto_name.as_ref() =>
{ {
return Ok(Async::Ready((proto_name, Negotiated(dialer.into_inner())))) return Ok(Async::Ready((proto_name, Negotiated(dialer.into_inner()))))

View File

@ -21,9 +21,8 @@
//! Main `ProtocolChoiceError` error. //! Main `ProtocolChoiceError` error.
use crate::protocol::MultistreamSelectError; use crate::protocol::MultistreamSelectError;
use std::error; use std::error::Error;
use std::fmt; use std::{fmt, io};
use std::io::Error as IoError;
/// Error that can happen when negotiating a protocol with the remote. /// Error that can happen when negotiating a protocol with the remote.
#[derive(Debug)] #[derive(Debug)]
@ -39,21 +38,18 @@ pub enum ProtocolChoiceError {
} }
impl From<MultistreamSelectError> for ProtocolChoiceError { impl From<MultistreamSelectError> for ProtocolChoiceError {
#[inline]
fn from(err: MultistreamSelectError) -> ProtocolChoiceError { fn from(err: MultistreamSelectError) -> ProtocolChoiceError {
ProtocolChoiceError::MultistreamSelectError(err) ProtocolChoiceError::MultistreamSelectError(err)
} }
} }
impl From<IoError> for ProtocolChoiceError { impl From<io::Error> for ProtocolChoiceError {
#[inline] fn from(err: io::Error) -> ProtocolChoiceError {
fn from(err: IoError) -> ProtocolChoiceError {
MultistreamSelectError::from(err).into() MultistreamSelectError::from(err).into()
} }
} }
impl error::Error for ProtocolChoiceError { impl Error for ProtocolChoiceError {
#[inline]
fn description(&self) -> &str { fn description(&self) -> &str {
match *self { match *self {
ProtocolChoiceError::MultistreamSelectError(_) => "error in the protocol", ProtocolChoiceError::MultistreamSelectError(_) => "error in the protocol",
@ -66,7 +62,7 @@ impl error::Error for ProtocolChoiceError {
} }
} }
fn cause(&self) -> Option<&dyn error::Error> { fn source(&self) -> Option<&(dyn Error + 'static)> {
match *self { match *self {
ProtocolChoiceError::MultistreamSelectError(ref err) => Some(err), ProtocolChoiceError::MultistreamSelectError(ref err) => Some(err),
_ => None, _ => None,
@ -75,8 +71,7 @@ impl error::Error for ProtocolChoiceError {
} }
impl fmt::Display for ProtocolChoiceError { impl fmt::Display for ProtocolChoiceError {
#[inline]
fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> { fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> {
write!(fmt, "{}", error::Error::description(self)) write!(fmt, "{}", Error::description(self))
} }
} }

View File

@ -18,78 +18,81 @@
// FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER // FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
// DEALINGS IN THE SOFTWARE. // DEALINGS IN THE SOFTWARE.
use bytes::Bytes; use bytes::{Bytes, BytesMut, BufMut};
use futures::{Async, Poll, Sink, StartSend, Stream}; use futures::{try_ready, Async, Poll, Sink, StartSend, Stream, AsyncSink};
use smallvec::SmallVec;
use std::{io, u16}; use std::{io, u16};
use tokio_codec::{Encoder, FramedWrite};
use tokio_io::{AsyncRead, AsyncWrite}; use tokio_io::{AsyncRead, AsyncWrite};
use unsigned_varint::decode; use unsigned_varint as uvi;
/// `Stream` and `Sink` wrapping some `AsyncRead + AsyncWrite` object to read const MAX_LEN_BYTES: u16 = 2;
const MAX_FRAME_SIZE: u16 = (1 << (MAX_LEN_BYTES * 8 - MAX_LEN_BYTES)) - 1;
const DEFAULT_BUFFER_SIZE: usize = 64;
/// `Stream` and `Sink` wrapping some `AsyncRead + AsyncWrite` resource to read
/// and write unsigned-varint prefixed frames. /// and write unsigned-varint prefixed frames.
/// ///
/// We purposely only support a frame length of under 64kiB. Frames mostly consist /// We purposely only support a frame sizes up to 16KiB (2 bytes unsigned varint
/// in a short protocol name, which is highly unlikely to be more than 64kiB long. /// frame length). Frames mostly consist in a short protocol name, which is highly
pub struct LengthDelimited<R, C> { /// unlikely to be more than 16KiB long.
// The inner socket where data is pulled from. pub struct LengthDelimited<R> {
inner: FramedWrite<R, C>, /// The inner I/O resource.
// Intermediary buffer where we put either the length of the next frame of data, or the frame inner: R,
// of data itself before it is returned. /// Read buffer for a single incoming unsigned-varint length-delimited frame.
// Must always contain enough space to read data from `inner`. read_buffer: BytesMut,
internal_buffer: SmallVec<[u8; 64]>, /// Write buffer for outgoing unsigned-varint length-delimited frames.
// Number of bytes within `internal_buffer` that contain valid data. write_buffer: BytesMut,
internal_buffer_pos: usize, /// The current read state, alternating between reading a frame
// State of the decoder. /// length and reading a frame payload.
state: State read_state: ReadState,
} }
#[derive(Debug, Copy, Clone, PartialEq, Eq)] #[derive(Debug, Copy, Clone, PartialEq, Eq)]
enum State { enum ReadState {
// We are currently reading the length of the next frame of data. /// We are currently reading the length of the next frame of data.
ReadingLength, ReadLength { buf: [u8; MAX_LEN_BYTES as usize], pos: usize },
// We are currently reading the frame of data itself. /// We are currently reading the frame of data itself.
ReadingData { frame_len: u16 }, ReadData { len: u16, pos: usize },
} }
impl<R, C> LengthDelimited<R, C> impl Default for ReadState {
where fn default() -> Self {
R: AsyncWrite, ReadState::ReadLength {
C: Encoder buf: [0; MAX_LEN_BYTES as usize],
{ pos: 0
pub fn new(inner: R, codec: C) -> LengthDelimited<R, C> { }
}
}
impl<R> LengthDelimited<R> {
/// Creates a new I/O resource for reading and writing unsigned-varint
/// length delimited frames.
pub fn new(inner: R) -> LengthDelimited<R> {
LengthDelimited { LengthDelimited {
inner: FramedWrite::new(inner, codec), inner,
internal_buffer: { read_state: ReadState::default(),
let mut v = SmallVec::new(); read_buffer: BytesMut::with_capacity(DEFAULT_BUFFER_SIZE),
v.push(0); write_buffer: BytesMut::with_capacity(DEFAULT_BUFFER_SIZE + MAX_LEN_BYTES as usize),
v
},
internal_buffer_pos: 0,
state: State::ReadingLength
} }
} }
/// Destroys the `LengthDelimited` and returns the underlying socket. /// Destroys the `LengthDelimited` and returns the underlying socket.
/// ///
/// Contrary to its equivalent `tokio_io::codec::length_delimited::FramedRead`, this method is /// This method is guaranteed not to skip any data from the socket.
/// guaranteed not to skip any data from the socket.
/// ///
/// # Panic /// # Panic
/// ///
/// Will panic if called while there is data inside the buffer. **This can only happen if /// Will panic if called while there is data inside the read or write buffer.
/// you call `poll()` manually**. Using this struct as it is intended to be used (i.e. through /// **This can only happen if you call `poll()` manually**. Using this struct
/// the modifiers provided by the `futures` crate) will always leave the object in a state in /// as it is intended to be used (i.e. through the high-level `futures` API)
/// which `into_inner()` will not panic. /// will always leave the object in a state in which `into_inner()` will not panic.
#[inline]
pub fn into_inner(self) -> R { pub fn into_inner(self) -> R {
assert_eq!(self.state, State::ReadingLength); assert!(self.write_buffer.is_empty());
assert_eq!(self.internal_buffer_pos, 0); assert!(self.read_buffer.is_empty());
self.inner.into_inner() self.inner
} }
} }
impl<R, C> Stream for LengthDelimited<R, C> impl<R> Stream for LengthDelimited<R>
where where
R: AsyncRead R: AsyncRead
{ {
@ -98,16 +101,11 @@ where
fn poll(&mut self) -> Poll<Option<Self::Item>, Self::Error> { fn poll(&mut self) -> Poll<Option<Self::Item>, Self::Error> {
loop { loop {
debug_assert!(!self.internal_buffer.is_empty()); match &mut self.read_state {
debug_assert!(self.internal_buffer_pos < self.internal_buffer.len()); ReadState::ReadLength { buf, pos } => {
match self.inner.read(&mut buf[*pos .. *pos + 1]) {
match self.state {
State::ReadingLength => {
let slice = &mut self.internal_buffer[self.internal_buffer_pos..];
match self.inner.get_mut().read(slice) {
Ok(0) => { Ok(0) => {
// EOF if *pos == 0 {
if self.internal_buffer_pos == 0 {
return Ok(Async::Ready(None)); return Ok(Async::Ready(None));
} else { } else {
return Err(io::ErrorKind::UnexpectedEof.into()); return Err(io::ErrorKind::UnexpectedEof.into());
@ -115,7 +113,7 @@ where
} }
Ok(n) => { Ok(n) => {
debug_assert_eq!(n, 1); debug_assert_eq!(n, 1);
self.internal_buffer_pos += n; *pos += n;
} }
Err(ref err) if err.kind() == io::ErrorKind::WouldBlock => { Err(ref err) if err.kind() == io::ErrorKind::WouldBlock => {
return Ok(Async::NotReady); return Ok(Async::NotReady);
@ -125,56 +123,45 @@ where
} }
}; };
debug_assert_eq!(self.internal_buffer.len(), self.internal_buffer_pos); if (buf[*pos - 1] & 0x80) == 0 {
// MSB is not set, indicating the end of the length prefix.
if (*self.internal_buffer.last().unwrap_or(&0) & 0x80) == 0 { let (len, _) = uvi::decode::u16(buf).map_err(|e| {
// End of length prefix. Most of the time we will switch to reading data,
// but we need to handle a few corner cases first.
let (frame_len, _) = decode::u16(&self.internal_buffer).map_err(|e| {
log::debug!("invalid length prefix: {}", e); log::debug!("invalid length prefix: {}", e);
io::Error::new(io::ErrorKind::InvalidData, "invalid length prefix") io::Error::new(io::ErrorKind::InvalidData, "invalid length prefix")
})?; })?;
if frame_len >= 1 { if len >= 1 {
self.state = State::ReadingData { frame_len }; self.read_state = ReadState::ReadData { len, pos: 0 };
self.internal_buffer.clear(); self.read_buffer.resize(len as usize, 0);
self.internal_buffer.reserve(frame_len as usize);
self.internal_buffer.extend((0..frame_len).map(|_| 0));
self.internal_buffer_pos = 0;
} else { } else {
debug_assert_eq!(frame_len, 0); debug_assert_eq!(len, 0);
self.state = State::ReadingLength; self.read_state = ReadState::default();
self.internal_buffer.clear(); return Ok(Async::Ready(Some(Bytes::new())));
self.internal_buffer.push(0);
self.internal_buffer_pos = 0;
return Ok(Async::Ready(Some(From::from(&[][..]))));
} }
} else if self.internal_buffer_pos >= 2 { } else if *pos == MAX_LEN_BYTES as usize {
// Length prefix is too long. See module doc for info about max frame len. // MSB signals more length bytes but we have already read the maximum.
return Err(io::Error::new(io::ErrorKind::InvalidData, "frame length too long")); // See the module documentation about the max frame len.
} else { return Err(io::Error::new(
// Prepare for next read. io::ErrorKind::InvalidData,
self.internal_buffer.push(0); "Maximum frame length exceeded"));
} }
} }
State::ReadingData { frame_len } => { ReadState::ReadData { len, pos } => {
let slice = &mut self.internal_buffer[self.internal_buffer_pos..]; match self.inner.read(&mut self.read_buffer[*pos..]) {
match self.inner.get_mut().read(slice) {
Ok(0) => return Err(io::ErrorKind::UnexpectedEof.into()), Ok(0) => return Err(io::ErrorKind::UnexpectedEof.into()),
Ok(n) => self.internal_buffer_pos += n, Ok(n) => *pos += n,
Err(ref err) if err.kind() == io::ErrorKind::WouldBlock => { Err(err) =>
return Ok(Async::NotReady) if err.kind() == io::ErrorKind::WouldBlock {
} return Ok(Async::NotReady)
Err(err) => return Err(err) } else {
return Err(err)
}
}; };
if self.internal_buffer_pos >= frame_len as usize { if *pos == *len as usize {
// Finished reading the frame of data. // Finished reading the frame.
self.state = State::ReadingLength; let frame = self.read_buffer.split_off(0).freeze();
let out_data = From::from(&self.internal_buffer[..]); self.read_state = ReadState::default();
self.internal_buffer.clear(); return Ok(Async::Ready(Some(frame)));
self.internal_buffer.push(0);
self.internal_buffer_pos = 0;
return Ok(Async::Ready(Some(out_data)));
} }
} }
} }
@ -182,27 +169,60 @@ where
} }
} }
impl<R, C> Sink for LengthDelimited<R, C> impl<R> Sink for LengthDelimited<R>
where where
R: AsyncWrite, R: AsyncWrite,
C: Encoder
{ {
type SinkItem = <FramedWrite<R, C> as Sink>::SinkItem; type SinkItem = Bytes;
type SinkError = <FramedWrite<R, C> as Sink>::SinkError; type SinkError = io::Error;
#[inline] fn start_send(&mut self, msg: Self::SinkItem) -> StartSend<Self::SinkItem, Self::SinkError> {
fn start_send(&mut self, item: Self::SinkItem) -> StartSend<Self::SinkItem, Self::SinkError> { // Use the maximum frame length also as a (soft) upper limit
self.inner.start_send(item) // for the entire write buffer. The actual (hard) limit is thus
// implied to be roughly 2 * MAX_FRAME_SIZE.
if self.write_buffer.len() >= MAX_FRAME_SIZE as usize {
self.poll_complete()?;
if self.write_buffer.len() >= MAX_FRAME_SIZE as usize {
return Ok(AsyncSink::NotReady(msg))
}
}
let len = msg.len() as u16;
if len > MAX_FRAME_SIZE {
return Err(io::Error::new(
io::ErrorKind::InvalidData,
"Maximum frame size exceeded."))
}
let mut uvi_buf = uvi::encode::u16_buffer();
let uvi_len = uvi::encode::u16(len, &mut uvi_buf);
self.write_buffer.reserve(len as usize + uvi_len.len());
self.write_buffer.put(uvi_len);
self.write_buffer.put(msg);
Ok(AsyncSink::Ready)
} }
#[inline]
fn poll_complete(&mut self) -> Poll<(), Self::SinkError> { fn poll_complete(&mut self) -> Poll<(), Self::SinkError> {
self.inner.poll_complete() while !self.write_buffer.is_empty() {
let n = try_ready!(self.inner.poll_write(&self.write_buffer));
if n == 0 {
return Err(io::Error::new(
io::ErrorKind::WriteZero,
"Failed to write buffered frame."))
}
let _ = self.write_buffer.split_to(n);
}
try_ready!(self.inner.poll_flush());
return Ok(Async::Ready(()));
} }
#[inline]
fn close(&mut self) -> Poll<(), Self::SinkError> { fn close(&mut self) -> Poll<(), Self::SinkError> {
self.inner.close() try_ready!(self.poll_complete());
Ok(self.inner.shutdown()?)
} }
} }
@ -211,12 +231,11 @@ mod tests {
use futures::{Future, Stream}; use futures::{Future, Stream};
use crate::length_delimited::LengthDelimited; use crate::length_delimited::LengthDelimited;
use std::io::{Cursor, ErrorKind}; use std::io::{Cursor, ErrorKind};
use unsigned_varint::codec::UviBytes;
#[test] #[test]
fn basic_read() { fn basic_read() {
let data = vec![6, 9, 8, 7, 6, 5, 4]; let data = vec![6, 9, 8, 7, 6, 5, 4];
let framed = LengthDelimited::new(Cursor::new(data), UviBytes::<Vec<_>>::default()); let framed = LengthDelimited::new(Cursor::new(data));
let recved = framed.collect().wait().unwrap(); let recved = framed.collect().wait().unwrap();
assert_eq!(recved, vec![vec![9, 8, 7, 6, 5, 4]]); assert_eq!(recved, vec![vec![9, 8, 7, 6, 5, 4]]);
} }
@ -224,7 +243,7 @@ mod tests {
#[test] #[test]
fn basic_read_two() { fn basic_read_two() {
let data = vec![6, 9, 8, 7, 6, 5, 4, 3, 9, 8, 7]; let data = vec![6, 9, 8, 7, 6, 5, 4, 3, 9, 8, 7];
let framed = LengthDelimited::new(Cursor::new(data), UviBytes::<Vec<_>>::default()); let framed = LengthDelimited::new(Cursor::new(data));
let recved = framed.collect().wait().unwrap(); let recved = framed.collect().wait().unwrap();
assert_eq!(recved, vec![vec![9, 8, 7, 6, 5, 4], vec![9, 8, 7]]); assert_eq!(recved, vec![vec![9, 8, 7, 6, 5, 4], vec![9, 8, 7]]);
} }
@ -236,7 +255,7 @@ mod tests {
let frame = (0..len).map(|n| (n & 0xff) as u8).collect::<Vec<_>>(); let frame = (0..len).map(|n| (n & 0xff) as u8).collect::<Vec<_>>();
let mut data = vec![(len & 0x7f) as u8 | 0x80, (len >> 7) as u8]; let mut data = vec![(len & 0x7f) as u8 | 0x80, (len >> 7) as u8];
data.extend(frame.clone().into_iter()); data.extend(frame.clone().into_iter());
let framed = LengthDelimited::new(Cursor::new(data), UviBytes::<Vec<_>>::default()); let framed = LengthDelimited::new(Cursor::new(data));
let recved = framed let recved = framed
.into_future() .into_future()
.map(|(m, _)| m) .map(|(m, _)| m)
@ -250,7 +269,7 @@ mod tests {
fn packet_len_too_long() { fn packet_len_too_long() {
let mut data = vec![0x81, 0x81, 0x1]; let mut data = vec![0x81, 0x81, 0x1];
data.extend((0..16513).map(|_| 0)); data.extend((0..16513).map(|_| 0));
let framed = LengthDelimited::new(Cursor::new(data), UviBytes::<Vec<_>>::default()); let framed = LengthDelimited::new(Cursor::new(data));
let recved = framed let recved = framed
.into_future() .into_future()
.map(|(m, _)| m) .map(|(m, _)| m)
@ -267,7 +286,7 @@ mod tests {
#[test] #[test]
fn empty_frames() { fn empty_frames() {
let data = vec![0, 0, 6, 9, 8, 7, 6, 5, 4, 0, 3, 9, 8, 7]; let data = vec![0, 0, 6, 9, 8, 7, 6, 5, 4, 0, 3, 9, 8, 7];
let framed = LengthDelimited::new(Cursor::new(data), UviBytes::<Vec<_>>::default()); let framed = LengthDelimited::new(Cursor::new(data));
let recved = framed.collect().wait().unwrap(); let recved = framed.collect().wait().unwrap();
assert_eq!( assert_eq!(
recved, recved,
@ -284,7 +303,7 @@ mod tests {
#[test] #[test]
fn unexpected_eof_in_len() { fn unexpected_eof_in_len() {
let data = vec![0x89]; let data = vec![0x89];
let framed = LengthDelimited::new(Cursor::new(data), UviBytes::<Vec<_>>::default()); let framed = LengthDelimited::new(Cursor::new(data));
let recved = framed.collect().wait(); let recved = framed.collect().wait();
if let Err(io_err) = recved { if let Err(io_err) = recved {
assert_eq!(io_err.kind(), ErrorKind::UnexpectedEof) assert_eq!(io_err.kind(), ErrorKind::UnexpectedEof)
@ -296,7 +315,7 @@ mod tests {
#[test] #[test]
fn unexpected_eof_in_data() { fn unexpected_eof_in_data() {
let data = vec![5]; let data = vec![5];
let framed = LengthDelimited::new(Cursor::new(data), UviBytes::<Vec<_>>::default()); let framed = LengthDelimited::new(Cursor::new(data));
let recved = framed.collect().wait(); let recved = framed.collect().wait();
if let Err(io_err) = recved { if let Err(io_err) = recved {
assert_eq!(io_err.kind(), ErrorKind::UnexpectedEof) assert_eq!(io_err.kind(), ErrorKind::UnexpectedEof)
@ -308,7 +327,7 @@ mod tests {
#[test] #[test]
fn unexpected_eof_in_data2() { fn unexpected_eof_in_data2() {
let data = vec![5, 9, 8, 7]; let data = vec![5, 9, 8, 7];
let framed = LengthDelimited::new(Cursor::new(data), UviBytes::<Vec<_>>::default()); let framed = LengthDelimited::new(Cursor::new(data));
let recved = framed.collect().wait(); let recved = framed.collect().wait();
if let Err(io_err) = recved { if let Err(io_err) = recved {
assert_eq!(io_err.kind(), ErrorKind::UnexpectedEof) assert_eq!(io_err.kind(), ErrorKind::UnexpectedEof)

View File

@ -21,7 +21,7 @@
//! # Multistream-select //! # Multistream-select
//! //!
//! This crate implements the `multistream-select` protocol, which is the protocol used by libp2p //! This crate implements the `multistream-select` protocol, which is the protocol used by libp2p
//! to negotiate which protocol to use with the remote. //! to negotiate which protocol to use with the remote on a connection or substream.
//! //!
//! > **Note**: This crate is used by the internals of *libp2p*, and it is not required to //! > **Note**: This crate is used by the internals of *libp2p*, and it is not required to
//! > understand it in order to use *libp2p*. //! > understand it in order to use *libp2p*.
@ -76,6 +76,7 @@ mod protocol;
use futures::prelude::*; use futures::prelude::*;
use std::io; use std::io;
use tokio_io::{AsyncRead, AsyncWrite};
pub use self::dialer_select::{dialer_select_proto, DialerSelectFuture}; pub use self::dialer_select::{dialer_select_proto, DialerSelectFuture};
pub use self::error::ProtocolChoiceError; pub use self::error::ProtocolChoiceError;
@ -93,9 +94,9 @@ where
} }
} }
impl<TInner> tokio_io::AsyncRead for Negotiated<TInner> impl<TInner> AsyncRead for Negotiated<TInner>
where where
TInner: tokio_io::AsyncRead TInner: AsyncRead
{ {
unsafe fn prepare_uninitialized_buffer(&self, buf: &mut [u8]) -> bool { unsafe fn prepare_uninitialized_buffer(&self, buf: &mut [u8]) -> bool {
self.0.prepare_uninitialized_buffer(buf) self.0.prepare_uninitialized_buffer(buf)
@ -119,9 +120,9 @@ where
} }
} }
impl<TInner> tokio_io::AsyncWrite for Negotiated<TInner> impl<TInner> AsyncWrite for Negotiated<TInner>
where where
TInner: tokio_io::AsyncWrite TInner: AsyncWrite
{ {
fn shutdown(&mut self) -> Poll<(), io::Error> { fn shutdown(&mut self) -> Poll<(), io::Error> {
self.0.shutdown() self.0.shutdown()

View File

@ -23,10 +23,10 @@
use futures::{prelude::*, sink, stream::StreamFuture}; use futures::{prelude::*, sink, stream::StreamFuture};
use crate::protocol::{ use crate::protocol::{
DialerToListenerMessage, Request,
Response,
Listener, Listener,
ListenerFuture, ListenerFuture,
ListenerToDialerMessage
}; };
use log::{debug, trace}; use log::{debug, trace};
use std::mem; use std::mem;
@ -126,13 +126,13 @@ where
Err((e, _)) => return Err(ProtocolChoiceError::from(e)) Err((e, _)) => return Err(ProtocolChoiceError::from(e))
}; };
match msg { match msg {
Some(DialerToListenerMessage::ProtocolsListRequest) => { Some(Request::ListProtocols) => {
trace!("protocols list response: {:?}", protocols trace!("protocols list response: {:?}", protocols
.into_iter() .into_iter()
.map(|p| p.as_ref().into()) .map(|p| p.as_ref().into())
.collect::<Vec<Vec<u8>>>()); .collect::<Vec<Vec<u8>>>());
let list = protocols.into_iter().collect(); let supported = protocols.into_iter().collect();
let msg = ListenerToDialerMessage::ProtocolsListResponse { list }; let msg = Response::SupportedProtocols { protocols: supported };
let sender = listener.send(msg); let sender = listener.send(msg);
self.inner = ListenerSelectState::Outgoing { self.inner = ListenerSelectState::Outgoing {
sender, sender,
@ -140,12 +140,12 @@ where
outcome: None outcome: None
} }
} }
Some(DialerToListenerMessage::ProtocolRequest { name }) => { Some(Request::Protocol { name }) => {
let mut outcome = None; let mut outcome = None;
let mut send_back = ListenerToDialerMessage::NotAvailable; let mut send_back = Response::ProtocolNotAvailable;
for supported in &protocols { for supported in &protocols {
if name.as_ref() == supported.as_ref() { if name.as_ref() == supported.as_ref() {
send_back = ListenerToDialerMessage::ProtocolAck { send_back = Response::Protocol {
name: supported.clone() name: supported.clone()
}; };
outcome = Some(supported); outcome = Some(supported);

View File

@ -20,23 +20,25 @@
//! Contains the `Dialer` wrapper, which allows raw communications with a listener. //! Contains the `Dialer` wrapper, which allows raw communications with a listener.
use bytes::{BufMut, Bytes, BytesMut}; use super::*;
use bytes::{Bytes, BytesMut};
use crate::length_delimited::LengthDelimited; use crate::length_delimited::LengthDelimited;
use crate::protocol::DialerToListenerMessage; use crate::protocol::{Request, Response, MultistreamSelectError};
use crate::protocol::ListenerToDialerMessage;
use crate::protocol::MultistreamSelectError;
use crate::protocol::MULTISTREAM_PROTOCOL_WITH_LF;
use futures::{prelude::*, sink, Async, StartSend, try_ready}; use futures::{prelude::*, sink, Async, StartSend, try_ready};
use std::io;
use tokio_codec::Encoder;
use tokio_io::{AsyncRead, AsyncWrite}; use tokio_io::{AsyncRead, AsyncWrite};
use unsigned_varint::{decode, codec::Uvi}; use std::marker;
use unsigned_varint as uvi;
/// The maximum number of supported protocols that can be processed.
const MAX_PROTOCOLS: usize = 1000;
/// Wraps around a `AsyncRead+AsyncWrite`. /// Wraps around a `AsyncRead+AsyncWrite`.
/// Assumes that we're on the dialer's side. Produces and accepts messages. /// Assumes that we're on the dialer's side. Produces and accepts messages.
pub struct Dialer<R, N> { pub struct Dialer<R, N> {
inner: LengthDelimited<R, MessageEncoder<N>>, inner: LengthDelimited<R>,
handshake_finished: bool handshake_finished: bool,
_protocol_name: marker::PhantomData<N>,
} }
impl<R, N> Dialer<R, N> impl<R, N> Dialer<R, N>
@ -45,15 +47,16 @@ where
N: AsRef<[u8]> N: AsRef<[u8]>
{ {
pub fn dial(inner: R) -> DialerFuture<R, N> { pub fn dial(inner: R) -> DialerFuture<R, N> {
let codec = MessageEncoder(std::marker::PhantomData); let io = LengthDelimited::new(inner);
let sender = LengthDelimited::new(inner, codec); let mut buf = BytesMut::new();
Header::Multistream10.encode(&mut buf);
DialerFuture { DialerFuture {
inner: sender.send(Message::Header) inner: io.send(buf.freeze()),
_protocol_name: marker::PhantomData,
} }
} }
/// Grants back the socket. Typically used after a `ProtocolAck` has been received. /// Grants back the socket. Typically used after a `ProtocolAck` has been received.
#[inline]
pub fn into_inner(self) -> R { pub fn into_inner(self) -> R {
self.inner.into_inner() self.inner.into_inner()
} }
@ -64,24 +67,22 @@ where
R: AsyncRead + AsyncWrite, R: AsyncRead + AsyncWrite,
N: AsRef<[u8]> N: AsRef<[u8]>
{ {
type SinkItem = DialerToListenerMessage<N>; type SinkItem = Request<N>;
type SinkError = MultistreamSelectError; type SinkError = MultistreamSelectError;
#[inline] fn start_send(&mut self, request: Self::SinkItem) -> StartSend<Self::SinkItem, Self::SinkError> {
fn start_send(&mut self, item: Self::SinkItem) -> StartSend<Self::SinkItem, Self::SinkError> { let mut msg = BytesMut::new();
match self.inner.start_send(Message::Body(item))? { request.encode(&mut msg)?;
AsyncSink::NotReady(Message::Body(item)) => Ok(AsyncSink::NotReady(item)), match self.inner.start_send(msg.freeze())? {
AsyncSink::NotReady(Message::Header) => unreachable!(), AsyncSink::NotReady(_) => Ok(AsyncSink::NotReady(request)),
AsyncSink::Ready => Ok(AsyncSink::Ready) AsyncSink::Ready => Ok(AsyncSink::Ready),
} }
} }
#[inline]
fn poll_complete(&mut self) -> Poll<(), Self::SinkError> { fn poll_complete(&mut self) -> Poll<(), Self::SinkError> {
Ok(self.inner.poll_complete()?) Ok(self.inner.poll_complete()?)
} }
#[inline]
fn close(&mut self) -> Poll<(), Self::SinkError> { fn close(&mut self) -> Poll<(), Self::SinkError> {
Ok(self.inner.close()?) Ok(self.inner.close()?)
} }
@ -91,20 +92,20 @@ impl<R, N> Stream for Dialer<R, N>
where where
R: AsyncRead + AsyncWrite R: AsyncRead + AsyncWrite
{ {
type Item = ListenerToDialerMessage<Bytes>; type Item = Response<Bytes>;
type Error = MultistreamSelectError; type Error = MultistreamSelectError;
fn poll(&mut self) -> Poll<Option<Self::Item>, Self::Error> { fn poll(&mut self) -> Poll<Option<Self::Item>, Self::Error> {
loop { loop {
let mut frame = match self.inner.poll() { let mut msg = match self.inner.poll() {
Ok(Async::Ready(Some(frame))) => frame, Ok(Async::Ready(Some(msg))) => msg,
Ok(Async::Ready(None)) => return Ok(Async::Ready(None)), Ok(Async::Ready(None)) => return Ok(Async::Ready(None)),
Ok(Async::NotReady) => return Ok(Async::NotReady), Ok(Async::NotReady) => return Ok(Async::NotReady),
Err(err) => return Err(err.into()), Err(err) => return Err(err.into()),
}; };
if !self.handshake_finished { if !self.handshake_finished {
if frame == MULTISTREAM_PROTOCOL_WITH_LF { if msg == MSG_MULTISTREAM_1_0 {
self.handshake_finished = true; self.handshake_finished = true;
continue; continue;
} else { } else {
@ -112,31 +113,31 @@ where
} }
} }
if frame.get(0) == Some(&b'/') && frame.last() == Some(&b'\n') { if msg.get(0) == Some(&b'/') && msg.last() == Some(&b'\n') {
let frame_len = frame.len(); let len = msg.len();
let protocol = frame.split_to(frame_len - 1); let name = msg.split_to(len - 1);
return Ok(Async::Ready(Some(ListenerToDialerMessage::ProtocolAck { return Ok(Async::Ready(Some(
name: protocol Response::Protocol { name }
}))); )));
} else if frame == b"na\n"[..] { } else if msg == MSG_PROTOCOL_NA {
return Ok(Async::Ready(Some(ListenerToDialerMessage::NotAvailable))); return Ok(Async::Ready(Some(Response::ProtocolNotAvailable)));
} else { } else {
// A varint number of protocols // A varint number of protocols
let (num_protocols, mut remaining) = decode::usize(&frame)?; let (num_protocols, mut remaining) = uvi::decode::usize(&msg)?;
if num_protocols > 1000 { // TODO: configurable limit if num_protocols > MAX_PROTOCOLS { // TODO: configurable limit
return Err(MultistreamSelectError::VarintParseError("too many protocols".into())) return Err(MultistreamSelectError::TooManyProtocols)
} }
let mut out = Vec::with_capacity(num_protocols); let mut protocols = Vec::with_capacity(num_protocols);
for _ in 0 .. num_protocols { for _ in 0 .. num_protocols {
let (len, rem) = decode::usize(remaining)?; let (len, rem) = uvi::decode::usize(remaining)?;
if len == 0 || len > rem.len() || rem[len - 1] != b'\n' { if len == 0 || len > rem.len() || rem[len - 1] != b'\n' {
return Err(MultistreamSelectError::UnknownMessage) return Err(MultistreamSelectError::UnknownMessage)
} }
out.push(Bytes::from(&rem[.. len - 1])); protocols.push(Bytes::from(&rem[.. len - 1]));
remaining = &rem[len ..] remaining = &rem[len ..]
} }
return Ok(Async::Ready(Some( return Ok(Async::Ready(Some(
ListenerToDialerMessage::ProtocolsListResponse { list: out }, Response::SupportedProtocols { protocols },
))); )));
} }
} }
@ -145,7 +146,8 @@ where
/// Future, returned by `Dialer::new`, which send the handshake and returns the actual `Dialer`. /// Future, returned by `Dialer::new`, which send the handshake and returns the actual `Dialer`.
pub struct DialerFuture<T: AsyncWrite, N: AsRef<[u8]>> { pub struct DialerFuture<T: AsyncWrite, N: AsRef<[u8]>> {
inner: sink::Send<LengthDelimited<T, MessageEncoder<N>>> inner: sink::Send<LengthDelimited<T>>,
_protocol_name: marker::PhantomData<N>,
} }
impl<T: AsyncWrite, N: AsRef<[u8]>> Future for DialerFuture<T, N> { impl<T: AsyncWrite, N: AsRef<[u8]>> Future for DialerFuture<T, N> {
@ -154,57 +156,17 @@ impl<T: AsyncWrite, N: AsRef<[u8]>> Future for DialerFuture<T, N> {
fn poll(&mut self) -> Poll<Self::Item, Self::Error> { fn poll(&mut self) -> Poll<Self::Item, Self::Error> {
let inner = try_ready!(self.inner.poll()); let inner = try_ready!(self.inner.poll());
Ok(Async::Ready(Dialer { inner, handshake_finished: false })) Ok(Async::Ready(Dialer {
} inner,
} handshake_finished: false,
_protocol_name: marker::PhantomData,
/// tokio-codec `Encoder` handling `DialerToListenerMessage` values. }))
struct MessageEncoder<N>(std::marker::PhantomData<N>);
enum Message<N> {
Header,
Body(DialerToListenerMessage<N>)
}
impl<N: AsRef<[u8]>> Encoder for MessageEncoder<N> {
type Item = Message<N>;
type Error = MultistreamSelectError;
fn encode(&mut self, item: Self::Item, dest: &mut BytesMut) -> Result<(), Self::Error> {
match item {
Message::Header => {
Uvi::<usize>::default().encode(MULTISTREAM_PROTOCOL_WITH_LF.len(), dest)?;
dest.reserve(MULTISTREAM_PROTOCOL_WITH_LF.len());
dest.put(MULTISTREAM_PROTOCOL_WITH_LF);
Ok(())
}
Message::Body(DialerToListenerMessage::ProtocolRequest { name }) => {
if !name.as_ref().starts_with(b"/") {
return Err(MultistreamSelectError::WrongProtocolName)
}
let len = name.as_ref().len() + 1; // + 1 for \n
if len > std::u16::MAX as usize {
return Err(io::Error::new(io::ErrorKind::InvalidData, "name too long").into())
}
Uvi::<usize>::default().encode(len, dest)?;
dest.reserve(len);
dest.put(name.as_ref());
dest.put(&b"\n"[..]);
Ok(())
}
Message::Body(DialerToListenerMessage::ProtocolsListRequest) => {
Uvi::<usize>::default().encode(3, dest)?;
dest.reserve(3);
dest.put(&b"ls\n"[..]);
Ok(())
}
}
} }
} }
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use crate::protocol::{Dialer, DialerToListenerMessage, MultistreamSelectError}; use super::*;
use tokio::runtime::current_thread::Runtime; use tokio::runtime::current_thread::Runtime;
use tokio_tcp::{TcpListener, TcpStream}; use tokio_tcp::{TcpListener, TcpStream};
use futures::Future; use futures::Future;
@ -225,13 +187,13 @@ mod tests {
.from_err() .from_err()
.and_then(move |stream| Dialer::dial(stream)) .and_then(move |stream| Dialer::dial(stream))
.and_then(move |dialer| { .and_then(move |dialer| {
let p = b"invalid_name"; let name = b"invalid_name";
dialer.send(DialerToListenerMessage::ProtocolRequest { name: p }) dialer.send(Request::Protocol { name })
}); });
let mut rt = Runtime::new().unwrap(); let mut rt = Runtime::new().unwrap();
match rt.block_on(server.join(client)) { match rt.block_on(server.join(client)) {
Err(MultistreamSelectError::WrongProtocolName) => (), Err(MultistreamSelectError::InvalidProtocolName) => (),
_ => panic!(), _ => panic!(),
} }
} }

View File

@ -20,7 +20,7 @@
//! Contains the error structs for the low-level protocol handling. //! Contains the error structs for the low-level protocol handling.
use std::error; use std::error::Error;
use std::fmt; use std::fmt;
use std::io; use std::io;
use unsigned_varint::decode; use unsigned_varint::decode;
@ -38,29 +38,25 @@ pub enum MultistreamSelectError {
UnknownMessage, UnknownMessage,
/// Protocol names must always start with `/`, otherwise this error is returned. /// Protocol names must always start with `/`, otherwise this error is returned.
WrongProtocolName, InvalidProtocolName,
/// Failure to parse variable-length integer. /// Too many protocols have been returned by the remote.
// TODO: we don't include the actual error, because that would remove Send from the enum TooManyProtocols,
VarintParseError(String),
} }
impl From<io::Error> for MultistreamSelectError { impl From<io::Error> for MultistreamSelectError {
#[inline]
fn from(err: io::Error) -> MultistreamSelectError { fn from(err: io::Error) -> MultistreamSelectError {
MultistreamSelectError::IoError(err) MultistreamSelectError::IoError(err)
} }
} }
impl From<decode::Error> for MultistreamSelectError { impl From<decode::Error> for MultistreamSelectError {
#[inline]
fn from(err: decode::Error) -> MultistreamSelectError { fn from(err: decode::Error) -> MultistreamSelectError {
MultistreamSelectError::VarintParseError(err.to_string()) Self::from(io::Error::new(io::ErrorKind::InvalidData, err.to_string()))
} }
} }
impl error::Error for MultistreamSelectError { impl Error for MultistreamSelectError {
#[inline]
fn description(&self) -> &str { fn description(&self) -> &str {
match *self { match *self {
MultistreamSelectError::IoError(_) => "I/O error", MultistreamSelectError::IoError(_) => "I/O error",
@ -68,16 +64,15 @@ impl error::Error for MultistreamSelectError {
"the remote doesn't use the same multistream-select protocol as we do" "the remote doesn't use the same multistream-select protocol as we do"
} }
MultistreamSelectError::UnknownMessage => "received an unknown message from the remote", MultistreamSelectError::UnknownMessage => "received an unknown message from the remote",
MultistreamSelectError::WrongProtocolName => { MultistreamSelectError::InvalidProtocolName => {
"protocol names must always start with `/`, otherwise this error is returned" "protocol names must always start with `/`, otherwise this error is returned"
} }
MultistreamSelectError::VarintParseError(_) => { MultistreamSelectError::TooManyProtocols =>
"failure to parse variable-length integer" "Too many protocols."
}
} }
} }
fn cause(&self) -> Option<&dyn error::Error> { fn source(&self) -> Option<&(dyn Error + 'static)> {
match *self { match *self {
MultistreamSelectError::IoError(ref err) => Some(err), MultistreamSelectError::IoError(ref err) => Some(err),
_ => None, _ => None,
@ -86,8 +81,7 @@ impl error::Error for MultistreamSelectError {
} }
impl fmt::Display for MultistreamSelectError { impl fmt::Display for MultistreamSelectError {
#[inline]
fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> { fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> {
write!(fmt, "{}", error::Error::description(self)) write!(fmt, "{}", Error::description(self))
} }
} }

View File

@ -20,23 +20,21 @@
//! Contains the `Listener` wrapper, which allows raw communications with a dialer. //! Contains the `Listener` wrapper, which allows raw communications with a dialer.
use bytes::{BufMut, Bytes, BytesMut}; use super::*;
use bytes::{Bytes, BytesMut};
use crate::length_delimited::LengthDelimited; use crate::length_delimited::LengthDelimited;
use crate::protocol::DialerToListenerMessage; use crate::protocol::{Request, Response, MultistreamSelectError};
use crate::protocol::ListenerToDialerMessage;
use crate::protocol::MultistreamSelectError;
use crate::protocol::MULTISTREAM_PROTOCOL_WITH_LF;
use futures::{prelude::*, sink, stream::StreamFuture}; use futures::{prelude::*, sink, stream::StreamFuture};
use log::{debug, trace}; use log::{debug, trace};
use std::{io, mem}; use std::{marker, mem};
use tokio_codec::Encoder;
use tokio_io::{AsyncRead, AsyncWrite}; use tokio_io::{AsyncRead, AsyncWrite};
use unsigned_varint::{encode, codec::Uvi};
/// Wraps around a `AsyncRead+AsyncWrite`. Assumes that we're on the listener's side. Produces and /// Wraps around a `AsyncRead+AsyncWrite`. Assumes that we're on the listener's side. Produces and
/// accepts messages. /// accepts messages.
pub struct Listener<R, N> { pub struct Listener<R, N> {
inner: LengthDelimited<R, MessageEncoder<N>> inner: LengthDelimited<R>,
_protocol_name: marker::PhantomData<N>,
} }
impl<R, N> Listener<R, N> impl<R, N> Listener<R, N>
@ -47,16 +45,15 @@ where
/// Takes ownership of a socket and starts the handshake. If the handshake succeeds, the /// Takes ownership of a socket and starts the handshake. If the handshake succeeds, the
/// future returns a `Listener`. /// future returns a `Listener`.
pub fn listen(inner: R) -> ListenerFuture<R, N> { pub fn listen(inner: R) -> ListenerFuture<R, N> {
let codec = MessageEncoder(std::marker::PhantomData); let inner = LengthDelimited::new(inner);
let inner = LengthDelimited::new(inner, codec);
ListenerFuture { ListenerFuture {
inner: ListenerFutureState::Await { inner: inner.into_future() } inner: ListenerFutureState::Await { inner: inner.into_future() },
_protocol_name: marker::PhantomData,
} }
} }
/// Grants back the socket. Typically used after a `ProtocolRequest` has been received and a /// Grants back the socket. Typically used after a `ProtocolRequest` has been received and a
/// `ProtocolAck` has been sent back. /// `ProtocolAck` has been sent back.
#[inline]
pub fn into_inner(self) -> R { pub fn into_inner(self) -> R {
self.inner.into_inner() self.inner.into_inner()
} }
@ -67,24 +64,22 @@ where
R: AsyncRead + AsyncWrite, R: AsyncRead + AsyncWrite,
N: AsRef<[u8]> N: AsRef<[u8]>
{ {
type SinkItem = ListenerToDialerMessage<N>; type SinkItem = Response<N>;
type SinkError = MultistreamSelectError; type SinkError = MultistreamSelectError;
#[inline] fn start_send(&mut self, response: Self::SinkItem) -> StartSend<Self::SinkItem, Self::SinkError> {
fn start_send(&mut self, item: Self::SinkItem) -> StartSend<Self::SinkItem, Self::SinkError> { let mut msg = BytesMut::new();
match self.inner.start_send(Message::Body(item))? { response.encode(&mut msg)?;
AsyncSink::NotReady(Message::Body(item)) => Ok(AsyncSink::NotReady(item)), match self.inner.start_send(msg.freeze())? {
AsyncSink::NotReady(Message::Header) => unreachable!(), AsyncSink::NotReady(_) => Ok(AsyncSink::NotReady(response)),
AsyncSink::Ready => Ok(AsyncSink::Ready) AsyncSink::Ready => Ok(AsyncSink::Ready)
} }
} }
#[inline]
fn poll_complete(&mut self) -> Poll<(), Self::SinkError> { fn poll_complete(&mut self) -> Poll<(), Self::SinkError> {
Ok(self.inner.poll_complete()?) Ok(self.inner.poll_complete()?)
} }
#[inline]
fn close(&mut self) -> Poll<(), Self::SinkError> { fn close(&mut self) -> Poll<(), Self::SinkError> {
Ok(self.inner.close()?) Ok(self.inner.close()?)
} }
@ -94,26 +89,26 @@ impl<R, N> Stream for Listener<R, N>
where where
R: AsyncRead + AsyncWrite, R: AsyncRead + AsyncWrite,
{ {
type Item = DialerToListenerMessage<Bytes>; type Item = Request<Bytes>;
type Error = MultistreamSelectError; type Error = MultistreamSelectError;
fn poll(&mut self) -> Poll<Option<Self::Item>, Self::Error> { fn poll(&mut self) -> Poll<Option<Self::Item>, Self::Error> {
let mut frame = match self.inner.poll() { let mut msg = match self.inner.poll() {
Ok(Async::Ready(Some(frame))) => frame, Ok(Async::Ready(Some(msg))) => msg,
Ok(Async::Ready(None)) => return Ok(Async::Ready(None)), Ok(Async::Ready(None)) => return Ok(Async::Ready(None)),
Ok(Async::NotReady) => return Ok(Async::NotReady), Ok(Async::NotReady) => return Ok(Async::NotReady),
Err(err) => return Err(err.into()), Err(err) => return Err(err.into()),
}; };
if frame.get(0) == Some(&b'/') && frame.last() == Some(&b'\n') { if msg.get(0) == Some(&b'/') && msg.last() == Some(&b'\n') {
let frame_len = frame.len(); let len = msg.len();
let protocol = frame.split_to(frame_len - 1); let name = msg.split_to(len - 1);
Ok(Async::Ready(Some( Ok(Async::Ready(Some(
DialerToListenerMessage::ProtocolRequest { name: protocol }, Request::Protocol { name },
))) )))
} else if frame == b"ls\n"[..] { } else if msg == MSG_LS {
Ok(Async::Ready(Some( Ok(Async::Ready(Some(
DialerToListenerMessage::ProtocolsListRequest, Request::ListProtocols,
))) )))
} else { } else {
Err(MultistreamSelectError::UnknownMessage) Err(MultistreamSelectError::UnknownMessage)
@ -124,16 +119,17 @@ where
/// Future, returned by `Listener::new` which performs the handshake and returns /// Future, returned by `Listener::new` which performs the handshake and returns
/// the `Listener` if successful. /// the `Listener` if successful.
pub struct ListenerFuture<T: AsyncRead + AsyncWrite, N: AsRef<[u8]>> { pub struct ListenerFuture<T: AsyncRead + AsyncWrite, N> {
inner: ListenerFutureState<T, N> inner: ListenerFutureState<T>,
_protocol_name: marker::PhantomData<N>,
} }
enum ListenerFutureState<T: AsyncRead + AsyncWrite, N: AsRef<[u8]>> { enum ListenerFutureState<T: AsyncRead + AsyncWrite> {
Await { Await {
inner: StreamFuture<LengthDelimited<T, MessageEncoder<N>>> inner: StreamFuture<LengthDelimited<T>>
}, },
Reply { Reply {
sender: sink::Send<LengthDelimited<T, MessageEncoder<N>>> sender: sink::Send<LengthDelimited<T>>
}, },
Undefined Undefined
} }
@ -155,12 +151,14 @@ impl<T: AsyncRead + AsyncWrite, N: AsRef<[u8]>> Future for ListenerFuture<T, N>
} }
Err((e, _)) => return Err(MultistreamSelectError::from(e)) Err((e, _)) => return Err(MultistreamSelectError::from(e))
}; };
if msg.as_ref().map(|b| &b[..]) != Some(MULTISTREAM_PROTOCOL_WITH_LF) { if msg.as_ref().map(|b| &b[..]) != Some(MSG_MULTISTREAM_1_0) {
debug!("failed handshake; received: {:?}", msg); debug!("Unexpected message: {:?}", msg);
return Err(MultistreamSelectError::FailedHandshake) return Err(MultistreamSelectError::FailedHandshake)
} }
trace!("sending back /multistream/<version> to finish the handshake"); trace!("sending back /multistream/<version> to finish the handshake");
let sender = socket.send(Message::Header); let mut frame = BytesMut::new();
Header::Multistream10.encode(&mut frame);
let sender = socket.send(frame.freeze());
self.inner = ListenerFutureState::Reply { sender } self.inner = ListenerFutureState::Reply { sender }
} }
ListenerFutureState::Reply { mut sender } => { ListenerFutureState::Reply { mut sender } => {
@ -171,70 +169,13 @@ impl<T: AsyncRead + AsyncWrite, N: AsRef<[u8]>> Future for ListenerFuture<T, N>
return Ok(Async::NotReady) return Ok(Async::NotReady)
} }
}; };
return Ok(Async::Ready(Listener { inner: listener })) return Ok(Async::Ready(Listener {
inner: listener,
_protocol_name: marker::PhantomData
}))
} }
ListenerFutureState::Undefined => panic!("ListenerFutureState::poll called after completion") ListenerFutureState::Undefined =>
} panic!("ListenerFutureState::poll called after completion")
}
}
}
/// tokio-codec `Encoder` handling `ListenerToDialerMessage` values.
struct MessageEncoder<N>(std::marker::PhantomData<N>);
enum Message<N> {
Header,
Body(ListenerToDialerMessage<N>)
}
impl<N: AsRef<[u8]>> Encoder for MessageEncoder<N> {
type Item = Message<N>;
type Error = MultistreamSelectError;
fn encode(&mut self, item: Self::Item, dest: &mut BytesMut) -> Result<(), Self::Error> {
match item {
Message::Header => {
Uvi::<usize>::default().encode(MULTISTREAM_PROTOCOL_WITH_LF.len(), dest)?;
dest.reserve(MULTISTREAM_PROTOCOL_WITH_LF.len());
dest.put(MULTISTREAM_PROTOCOL_WITH_LF);
Ok(())
}
Message::Body(ListenerToDialerMessage::ProtocolAck { name }) => {
if !name.as_ref().starts_with(b"/") {
return Err(MultistreamSelectError::WrongProtocolName)
}
let len = name.as_ref().len() + 1; // + 1 for \n
if len > std::u16::MAX as usize {
return Err(io::Error::new(io::ErrorKind::InvalidData, "name too long").into())
}
Uvi::<usize>::default().encode(len, dest)?;
dest.reserve(len);
dest.put(name.as_ref());
dest.put(&b"\n"[..]);
Ok(())
}
Message::Body(ListenerToDialerMessage::ProtocolsListResponse { list }) => {
let mut buf = encode::usize_buffer();
let mut out_msg = Vec::from(encode::usize(list.len(), &mut buf));
for e in &list {
if e.as_ref().len() + 1 > std::u16::MAX as usize {
return Err(io::Error::new(io::ErrorKind::InvalidData, "name too long").into())
}
out_msg.extend(encode::usize(e.as_ref().len() + 1, &mut buf)); // +1 for '\n'
out_msg.extend_from_slice(e.as_ref());
out_msg.push(b'\n')
}
let len = encode::usize(out_msg.len(), &mut buf);
dest.reserve(len.len() + out_msg.len());
dest.put(len);
dest.put(out_msg);
Ok(())
}
Message::Body(ListenerToDialerMessage::NotAvailable) => {
Uvi::<usize>::default().encode(3, dest)?;
dest.reserve(3);
dest.put(&b"na\n"[..]);
Ok(())
} }
} }
} }
@ -242,12 +183,12 @@ impl<N: AsRef<[u8]>> Encoder for MessageEncoder<N> {
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::*;
use tokio::runtime::current_thread::Runtime; use tokio::runtime::current_thread::Runtime;
use tokio_tcp::{TcpListener, TcpStream}; use tokio_tcp::{TcpListener, TcpStream};
use bytes::Bytes; use bytes::Bytes;
use futures::Future; use futures::Future;
use futures::{Sink, Stream}; use futures::{Sink, Stream};
use crate::protocol::{Dialer, Listener, ListenerToDialerMessage, MultistreamSelectError};
#[test] #[test]
fn wrong_proto_name() { fn wrong_proto_name() {
@ -260,8 +201,8 @@ mod tests {
.map_err(|(e, _)| e.into()) .map_err(|(e, _)| e.into())
.and_then(move |(connec, _)| Listener::listen(connec.unwrap())) .and_then(move |(connec, _)| Listener::listen(connec.unwrap()))
.and_then(|listener| { .and_then(|listener| {
let proto_name = Bytes::from("invalid-proto"); let name = Bytes::from("invalid-proto");
listener.send(ListenerToDialerMessage::ProtocolAck { name: proto_name }) listener.send(Response::Protocol { name })
}); });
let client = TcpStream::connect(&listener_addr) let client = TcpStream::connect(&listener_addr)
@ -270,7 +211,7 @@ mod tests {
let mut rt = Runtime::new().unwrap(); let mut rt = Runtime::new().unwrap();
match rt.block_on(server.join(client)) { match rt.block_on(server.join(client)) {
Err(MultistreamSelectError::WrongProtocolName) => (), Err(MultistreamSelectError::InvalidProtocolName) => (),
_ => panic!(), _ => panic!(),
} }
} }

View File

@ -20,47 +20,125 @@
//! Contains lower-level structs to handle the multistream protocol. //! Contains lower-level structs to handle the multistream protocol.
const MSG_MULTISTREAM_1_0: &[u8] = b"/multistream/1.0.0\n";
const MSG_PROTOCOL_NA: &[u8] = b"na\n";
const MSG_LS: &[u8] = b"ls\n";
mod dialer; mod dialer;
mod error; mod error;
mod listener; mod listener;
const MULTISTREAM_PROTOCOL_WITH_LF: &[u8] = b"/multistream/1.0.0\n";
pub use self::dialer::{Dialer, DialerFuture}; pub use self::dialer::{Dialer, DialerFuture};
pub use self::error::MultistreamSelectError; pub use self::error::MultistreamSelectError;
pub use self::listener::{Listener, ListenerFuture}; pub use self::listener::{Listener, ListenerFuture};
use bytes::{BytesMut, BufMut};
use unsigned_varint as uvi;
pub enum Header {
Multistream10
}
impl Header {
fn encode(&self, dest: &mut BytesMut) {
match self {
Header::Multistream10 => {
dest.reserve(MSG_MULTISTREAM_1_0.len());
dest.put(MSG_MULTISTREAM_1_0);
}
}
}
}
/// Message sent from the dialer to the listener. /// Message sent from the dialer to the listener.
#[derive(Debug, Clone, PartialEq, Eq)] #[derive(Debug, Clone, PartialEq, Eq)]
pub enum DialerToListenerMessage<N> { pub enum Request<N> {
/// The dialer wants us to use a protocol. /// The dialer wants us to use a protocol.
/// ///
/// If this is accepted (by receiving back a `ProtocolAck`), then we immediately start /// If this is accepted (by receiving back a `ProtocolAck`), then we immediately start
/// communicating in the new protocol. /// communicating in the new protocol.
ProtocolRequest { Protocol {
/// Name of the protocol. /// Name of the protocol.
name: N name: N
}, },
/// The dialer requested the list of protocols that the listener supports. /// The dialer requested the list of protocols that the listener supports.
ProtocolsListRequest, ListProtocols,
} }
impl<N: AsRef<[u8]>> Request<N> {
fn encode(&self, dest: &mut BytesMut) -> Result<(), MultistreamSelectError> {
match self {
Request::Protocol { name } => {
if !name.as_ref().starts_with(b"/") {
return Err(MultistreamSelectError::InvalidProtocolName)
}
let len = name.as_ref().len() + 1; // + 1 for \n
dest.reserve(len);
dest.put(name.as_ref());
dest.put(&b"\n"[..]);
Ok(())
}
Request::ListProtocols => {
dest.reserve(MSG_LS.len());
dest.put(MSG_LS);
Ok(())
}
}
}
}
/// Message sent from the listener to the dialer. /// Message sent from the listener to the dialer.
#[derive(Debug, Clone, PartialEq, Eq)] #[derive(Debug, Clone, PartialEq, Eq)]
pub enum ListenerToDialerMessage<N> { pub enum Response<N> {
/// The protocol requested by the dialer is accepted. The socket immediately starts using the /// The protocol requested by the dialer is accepted. The socket immediately starts using the
/// new protocol. /// new protocol.
ProtocolAck { name: N }, Protocol { name: N },
/// The protocol requested by the dialer is not supported or available. /// The protocol requested by the dialer is not supported or available.
NotAvailable, ProtocolNotAvailable,
/// Response to the request for the list of protocols. /// Response to the request for the list of protocols.
ProtocolsListResponse { SupportedProtocols {
/// The list of protocols. /// The list of protocols.
// TODO: use some sort of iterator // TODO: use some sort of iterator
list: Vec<N>, protocols: Vec<N>,
}, },
} }
impl<N: AsRef<[u8]>> Response<N> {
fn encode(&self, dest: &mut BytesMut) -> Result<(), MultistreamSelectError> {
match self {
Response::Protocol { name } => {
if !name.as_ref().starts_with(b"/") {
return Err(MultistreamSelectError::InvalidProtocolName)
}
let len = name.as_ref().len() + 1; // + 1 for \n
dest.reserve(len);
dest.put(name.as_ref());
dest.put(&b"\n"[..]);
Ok(())
}
Response::SupportedProtocols { protocols } => {
let mut buf = uvi::encode::usize_buffer();
let mut out_msg = Vec::from(uvi::encode::usize(protocols.len(), &mut buf));
for p in protocols {
out_msg.extend(uvi::encode::usize(p.as_ref().len() + 1, &mut buf)); // +1 for '\n'
out_msg.extend_from_slice(p.as_ref());
out_msg.push(b'\n')
}
dest.reserve(out_msg.len());
dest.put(out_msg);
Ok(())
}
Response::ProtocolNotAvailable => {
dest.reserve(MSG_PROTOCOL_NA.len());
dest.put(MSG_PROTOCOL_NA);
Ok(())
}
}
}
}

View File

@ -24,7 +24,7 @@
use crate::ProtocolChoiceError; use crate::ProtocolChoiceError;
use crate::dialer_select::{dialer_select_proto_parallel, dialer_select_proto_serial}; use crate::dialer_select::{dialer_select_proto_parallel, dialer_select_proto_serial};
use crate::protocol::{Dialer, DialerToListenerMessage, Listener, ListenerToDialerMessage}; use crate::protocol::{Dialer, Request, Listener, Response};
use crate::{dialer_select_proto, listener_select_proto}; use crate::{dialer_select_proto, listener_select_proto};
use futures::prelude::*; use futures::prelude::*;
use tokio::runtime::current_thread::Runtime; use tokio::runtime::current_thread::Runtime;
@ -56,23 +56,23 @@ fn negotiate_with_self_succeeds() {
.and_then(|l| l.into_future().map_err(|(e, _)| e)) .and_then(|l| l.into_future().map_err(|(e, _)| e))
.and_then(|(msg, rest)| { .and_then(|(msg, rest)| {
let proto = match msg { let proto = match msg {
Some(DialerToListenerMessage::ProtocolRequest { name }) => name, Some(Request::Protocol { name }) => name,
_ => panic!(), _ => panic!(),
}; };
rest.send(ListenerToDialerMessage::ProtocolAck { name: proto }) rest.send(Response::Protocol { name: proto })
}); });
let client = TcpStream::connect(&listener_addr) let client = TcpStream::connect(&listener_addr)
.from_err() .from_err()
.and_then(move |stream| Dialer::dial(stream)) .and_then(move |stream| Dialer::dial(stream))
.and_then(move |dialer| { .and_then(move |dialer| {
let p = b"/hello/1.0.0"; let name = b"/hello/1.0.0";
dialer.send(DialerToListenerMessage::ProtocolRequest { name: p }) dialer.send(Request::Protocol { name })
}) })
.and_then(move |dialer| dialer.into_future().map_err(|(e, _)| e)) .and_then(move |dialer| dialer.into_future().map_err(|(e, _)| e))
.and_then(move |(msg, _)| { .and_then(move |(msg, _)| {
let proto = match msg { let proto = match msg {
Some(ListenerToDialerMessage::ProtocolAck { name }) => name, Some(Response::Protocol { name }) => name,
_ => panic!(), _ => panic!(),
}; };
assert_eq!(proto, "/hello/1.0.0"); assert_eq!(proto, "/hello/1.0.0");