Remove tokio-codec dependency from multistream-select. (#1203)

* Remove tokio-codec dependency from multistream-select.

In preparation for the eventual switch from tokio to std futures.

Includes some initial refactoring in preparation for further work
in the context of https://github.com/libp2p/rust-libp2p/issues/659.

* Reduce default buffer sizes.

* Allow more than one frame to be buffered for sending.

* Doc tweaks.

* Remove superfluous (duplicated) Message types.
This commit is contained in:
Roman Borschel 2019-07-29 17:06:23 +02:00 committed by GitHub
parent bcc7c4d349
commit 2fd941122a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
11 changed files with 383 additions and 401 deletions

View File

@ -14,9 +14,8 @@ bytes = "0.4"
futures = { version = "0.1" }
log = "0.4"
smallvec = "0.6"
tokio-codec = "0.1"
tokio-io = "0.1"
unsigned-varint = { version = "0.2.1", features = ["codec"] }
unsigned-varint = { version = "0.2.2" }
[dev-dependencies]
tokio = "0.1"

View File

@ -22,19 +22,14 @@
//! `multistream-select` for the dialer.
use futures::{future::Either, prelude::*, stream::StreamFuture};
use crate::protocol::{
Dialer,
DialerFuture,
DialerToListenerMessage,
ListenerToDialerMessage
};
use crate::protocol::{Dialer, DialerFuture, Request, Response};
use log::trace;
use std::mem;
use tokio_io::{AsyncRead, AsyncWrite};
use crate::{Negotiated, ProtocolChoiceError};
/// Future, returned by `dialer_select_proto`, which selects a protocol and dialer
/// either sequentially of by considering all protocols in parallel.
/// either sequentially or by considering all protocols in parallel.
pub type DialerSelectFuture<R, I> = Either<DialerSelectSeq<R, I>, DialerSelectPar<R, I>>;
/// Helps selecting a protocol amongst the ones supported.
@ -75,7 +70,10 @@ where
{
let protocols = protocols.into_iter();
DialerSelectSeq {
inner: DialerSelectSeqState::AwaitDialer { dialer_fut: Dialer::dial(inner), protocols }
inner: DialerSelectSeqState::AwaitDialer {
dialer_fut: Dialer::dial(inner),
protocols
}
}
}
@ -148,9 +146,7 @@ where
}
DialerSelectSeqState::NextProtocol { mut dialer, protocols, proto_name } => {
trace!("sending {:?}", proto_name.as_ref());
let req = DialerToListenerMessage::ProtocolRequest {
name: proto_name.clone()
};
let req = Request::Protocol { name: proto_name.clone() };
match dialer.start_send(req)? {
AsyncSink::Ready => {
self.inner = DialerSelectSeqState::FlushProtocol {
@ -204,12 +200,12 @@ where
};
trace!("received {:?}", m);
match m.ok_or(ProtocolChoiceError::UnexpectedMessage)? {
ListenerToDialerMessage::ProtocolAck { ref name }
Response::Protocol { ref name }
if name.as_ref() == proto_name.as_ref() =>
{
return Ok(Async::Ready((proto_name, Negotiated(r.into_inner()))))
}
ListenerToDialerMessage::NotAvailable => {
Response::ProtocolNotAvailable => {
let proto_name = protocols.next()
.ok_or(ProtocolChoiceError::NoProtocolFound)?;
self.inner = DialerSelectSeqState::NextProtocol {
@ -244,9 +240,8 @@ where
}
}
/// Future, returned by `dialer_select_proto_parallel`, which selects a protocol and dialer in
/// parellel, by first requesting the liste of protocols supported by the remote endpoint and
/// parallel, by first requesting the list of protocols supported by the remote endpoint and
/// then selecting the most appropriate one by applying a match predicate to the result.
pub struct DialerSelectPar<R, I>
where
@ -319,7 +314,7 @@ where
}
DialerSelectParState::ProtocolList { mut dialer, protocols } => {
trace!("requesting protocols list");
match dialer.start_send(DialerToListenerMessage::ProtocolsListRequest)? {
match dialer.start_send(Request::ListProtocols)? {
AsyncSink::Ready => {
self.inner = DialerSelectParState::FlushListRequest {
dialer,
@ -359,15 +354,15 @@ where
Err((e, _)) => return Err(ProtocolChoiceError::from(e))
};
trace!("protocols list response: {:?}", resp);
let list =
if let Some(ListenerToDialerMessage::ProtocolsListResponse { list }) = resp {
list
let supported =
if let Some(Response::SupportedProtocols { protocols }) = resp {
protocols
} else {
return Err(ProtocolChoiceError::UnexpectedMessage)
};
let mut found = None;
for local_name in protocols {
for remote_name in &list {
for remote_name in &supported {
if remote_name.as_ref() == local_name.as_ref() {
found = Some(local_name);
break;
@ -381,10 +376,8 @@ where
self.inner = DialerSelectParState::Protocol { dialer, proto_name }
}
DialerSelectParState::Protocol { mut dialer, proto_name } => {
trace!("requesting protocol: {:?}", proto_name.as_ref());
let req = DialerToListenerMessage::ProtocolRequest {
name: proto_name.clone()
};
trace!("Requesting protocol: {:?}", proto_name.as_ref());
let req = Request::Protocol { name: proto_name.clone() };
match dialer.start_send(req)? {
AsyncSink::Ready => {
self.inner = DialerSelectParState::FlushProtocol { dialer, proto_name }
@ -420,7 +413,7 @@ where
};
trace!("received {:?}", resp);
match resp {
Some(ListenerToDialerMessage::ProtocolAck { ref name })
Some(Response::Protocol { ref name })
if name.as_ref() == proto_name.as_ref() =>
{
return Ok(Async::Ready((proto_name, Negotiated(dialer.into_inner()))))

View File

@ -21,9 +21,8 @@
//! Main `ProtocolChoiceError` error.
use crate::protocol::MultistreamSelectError;
use std::error;
use std::fmt;
use std::io::Error as IoError;
use std::error::Error;
use std::{fmt, io};
/// Error that can happen when negotiating a protocol with the remote.
#[derive(Debug)]
@ -39,21 +38,18 @@ pub enum ProtocolChoiceError {
}
impl From<MultistreamSelectError> for ProtocolChoiceError {
#[inline]
fn from(err: MultistreamSelectError) -> ProtocolChoiceError {
ProtocolChoiceError::MultistreamSelectError(err)
}
}
impl From<IoError> for ProtocolChoiceError {
#[inline]
fn from(err: IoError) -> ProtocolChoiceError {
impl From<io::Error> for ProtocolChoiceError {
fn from(err: io::Error) -> ProtocolChoiceError {
MultistreamSelectError::from(err).into()
}
}
impl error::Error for ProtocolChoiceError {
#[inline]
impl Error for ProtocolChoiceError {
fn description(&self) -> &str {
match *self {
ProtocolChoiceError::MultistreamSelectError(_) => "error in the protocol",
@ -66,7 +62,7 @@ impl error::Error for ProtocolChoiceError {
}
}
fn cause(&self) -> Option<&dyn error::Error> {
fn source(&self) -> Option<&(dyn Error + 'static)> {
match *self {
ProtocolChoiceError::MultistreamSelectError(ref err) => Some(err),
_ => None,
@ -75,8 +71,7 @@ impl error::Error for ProtocolChoiceError {
}
impl fmt::Display for ProtocolChoiceError {
#[inline]
fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> {
write!(fmt, "{}", error::Error::description(self))
write!(fmt, "{}", Error::description(self))
}
}

View File

@ -18,78 +18,81 @@
// FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
// DEALINGS IN THE SOFTWARE.
use bytes::Bytes;
use futures::{Async, Poll, Sink, StartSend, Stream};
use smallvec::SmallVec;
use bytes::{Bytes, BytesMut, BufMut};
use futures::{try_ready, Async, Poll, Sink, StartSend, Stream, AsyncSink};
use std::{io, u16};
use tokio_codec::{Encoder, FramedWrite};
use tokio_io::{AsyncRead, AsyncWrite};
use unsigned_varint::decode;
use unsigned_varint as uvi;
/// `Stream` and `Sink` wrapping some `AsyncRead + AsyncWrite` object to read
const MAX_LEN_BYTES: u16 = 2;
const MAX_FRAME_SIZE: u16 = (1 << (MAX_LEN_BYTES * 8 - MAX_LEN_BYTES)) - 1;
const DEFAULT_BUFFER_SIZE: usize = 64;
/// `Stream` and `Sink` wrapping some `AsyncRead + AsyncWrite` resource to read
/// and write unsigned-varint prefixed frames.
///
/// We purposely only support a frame length of under 64kiB. Frames mostly consist
/// in a short protocol name, which is highly unlikely to be more than 64kiB long.
pub struct LengthDelimited<R, C> {
// The inner socket where data is pulled from.
inner: FramedWrite<R, C>,
// Intermediary buffer where we put either the length of the next frame of data, or the frame
// of data itself before it is returned.
// Must always contain enough space to read data from `inner`.
internal_buffer: SmallVec<[u8; 64]>,
// Number of bytes within `internal_buffer` that contain valid data.
internal_buffer_pos: usize,
// State of the decoder.
state: State
/// We purposely only support a frame sizes up to 16KiB (2 bytes unsigned varint
/// frame length). Frames mostly consist in a short protocol name, which is highly
/// unlikely to be more than 16KiB long.
pub struct LengthDelimited<R> {
/// The inner I/O resource.
inner: R,
/// Read buffer for a single incoming unsigned-varint length-delimited frame.
read_buffer: BytesMut,
/// Write buffer for outgoing unsigned-varint length-delimited frames.
write_buffer: BytesMut,
/// The current read state, alternating between reading a frame
/// length and reading a frame payload.
read_state: ReadState,
}
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
enum State {
// We are currently reading the length of the next frame of data.
ReadingLength,
// We are currently reading the frame of data itself.
ReadingData { frame_len: u16 },
enum ReadState {
/// We are currently reading the length of the next frame of data.
ReadLength { buf: [u8; MAX_LEN_BYTES as usize], pos: usize },
/// We are currently reading the frame of data itself.
ReadData { len: u16, pos: usize },
}
impl<R, C> LengthDelimited<R, C>
where
R: AsyncWrite,
C: Encoder
{
pub fn new(inner: R, codec: C) -> LengthDelimited<R, C> {
impl Default for ReadState {
fn default() -> Self {
ReadState::ReadLength {
buf: [0; MAX_LEN_BYTES as usize],
pos: 0
}
}
}
impl<R> LengthDelimited<R> {
/// Creates a new I/O resource for reading and writing unsigned-varint
/// length delimited frames.
pub fn new(inner: R) -> LengthDelimited<R> {
LengthDelimited {
inner: FramedWrite::new(inner, codec),
internal_buffer: {
let mut v = SmallVec::new();
v.push(0);
v
},
internal_buffer_pos: 0,
state: State::ReadingLength
inner,
read_state: ReadState::default(),
read_buffer: BytesMut::with_capacity(DEFAULT_BUFFER_SIZE),
write_buffer: BytesMut::with_capacity(DEFAULT_BUFFER_SIZE + MAX_LEN_BYTES as usize),
}
}
/// Destroys the `LengthDelimited` and returns the underlying socket.
///
/// Contrary to its equivalent `tokio_io::codec::length_delimited::FramedRead`, this method is
/// guaranteed not to skip any data from the socket.
/// This method is guaranteed not to skip any data from the socket.
///
/// # Panic
///
/// Will panic if called while there is data inside the buffer. **This can only happen if
/// you call `poll()` manually**. Using this struct as it is intended to be used (i.e. through
/// the modifiers provided by the `futures` crate) will always leave the object in a state in
/// which `into_inner()` will not panic.
#[inline]
/// Will panic if called while there is data inside the read or write buffer.
/// **This can only happen if you call `poll()` manually**. Using this struct
/// as it is intended to be used (i.e. through the high-level `futures` API)
/// will always leave the object in a state in which `into_inner()` will not panic.
pub fn into_inner(self) -> R {
assert_eq!(self.state, State::ReadingLength);
assert_eq!(self.internal_buffer_pos, 0);
self.inner.into_inner()
assert!(self.write_buffer.is_empty());
assert!(self.read_buffer.is_empty());
self.inner
}
}
impl<R, C> Stream for LengthDelimited<R, C>
impl<R> Stream for LengthDelimited<R>
where
R: AsyncRead
{
@ -98,16 +101,11 @@ where
fn poll(&mut self) -> Poll<Option<Self::Item>, Self::Error> {
loop {
debug_assert!(!self.internal_buffer.is_empty());
debug_assert!(self.internal_buffer_pos < self.internal_buffer.len());
match self.state {
State::ReadingLength => {
let slice = &mut self.internal_buffer[self.internal_buffer_pos..];
match self.inner.get_mut().read(slice) {
match &mut self.read_state {
ReadState::ReadLength { buf, pos } => {
match self.inner.read(&mut buf[*pos .. *pos + 1]) {
Ok(0) => {
// EOF
if self.internal_buffer_pos == 0 {
if *pos == 0 {
return Ok(Async::Ready(None));
} else {
return Err(io::ErrorKind::UnexpectedEof.into());
@ -115,7 +113,7 @@ where
}
Ok(n) => {
debug_assert_eq!(n, 1);
self.internal_buffer_pos += n;
*pos += n;
}
Err(ref err) if err.kind() == io::ErrorKind::WouldBlock => {
return Ok(Async::NotReady);
@ -125,56 +123,45 @@ where
}
};
debug_assert_eq!(self.internal_buffer.len(), self.internal_buffer_pos);
if (*self.internal_buffer.last().unwrap_or(&0) & 0x80) == 0 {
// End of length prefix. Most of the time we will switch to reading data,
// but we need to handle a few corner cases first.
let (frame_len, _) = decode::u16(&self.internal_buffer).map_err(|e| {
if (buf[*pos - 1] & 0x80) == 0 {
// MSB is not set, indicating the end of the length prefix.
let (len, _) = uvi::decode::u16(buf).map_err(|e| {
log::debug!("invalid length prefix: {}", e);
io::Error::new(io::ErrorKind::InvalidData, "invalid length prefix")
})?;
if frame_len >= 1 {
self.state = State::ReadingData { frame_len };
self.internal_buffer.clear();
self.internal_buffer.reserve(frame_len as usize);
self.internal_buffer.extend((0..frame_len).map(|_| 0));
self.internal_buffer_pos = 0;
if len >= 1 {
self.read_state = ReadState::ReadData { len, pos: 0 };
self.read_buffer.resize(len as usize, 0);
} else {
debug_assert_eq!(frame_len, 0);
self.state = State::ReadingLength;
self.internal_buffer.clear();
self.internal_buffer.push(0);
self.internal_buffer_pos = 0;
return Ok(Async::Ready(Some(From::from(&[][..]))));
debug_assert_eq!(len, 0);
self.read_state = ReadState::default();
return Ok(Async::Ready(Some(Bytes::new())));
}
} else if self.internal_buffer_pos >= 2 {
// Length prefix is too long. See module doc for info about max frame len.
return Err(io::Error::new(io::ErrorKind::InvalidData, "frame length too long"));
} else {
// Prepare for next read.
self.internal_buffer.push(0);
} else if *pos == MAX_LEN_BYTES as usize {
// MSB signals more length bytes but we have already read the maximum.
// See the module documentation about the max frame len.
return Err(io::Error::new(
io::ErrorKind::InvalidData,
"Maximum frame length exceeded"));
}
}
State::ReadingData { frame_len } => {
let slice = &mut self.internal_buffer[self.internal_buffer_pos..];
match self.inner.get_mut().read(slice) {
ReadState::ReadData { len, pos } => {
match self.inner.read(&mut self.read_buffer[*pos..]) {
Ok(0) => return Err(io::ErrorKind::UnexpectedEof.into()),
Ok(n) => self.internal_buffer_pos += n,
Err(ref err) if err.kind() == io::ErrorKind::WouldBlock => {
return Ok(Async::NotReady)
}
Err(err) => return Err(err)
Ok(n) => *pos += n,
Err(err) =>
if err.kind() == io::ErrorKind::WouldBlock {
return Ok(Async::NotReady)
} else {
return Err(err)
}
};
if self.internal_buffer_pos >= frame_len as usize {
// Finished reading the frame of data.
self.state = State::ReadingLength;
let out_data = From::from(&self.internal_buffer[..]);
self.internal_buffer.clear();
self.internal_buffer.push(0);
self.internal_buffer_pos = 0;
return Ok(Async::Ready(Some(out_data)));
if *pos == *len as usize {
// Finished reading the frame.
let frame = self.read_buffer.split_off(0).freeze();
self.read_state = ReadState::default();
return Ok(Async::Ready(Some(frame)));
}
}
}
@ -182,27 +169,60 @@ where
}
}
impl<R, C> Sink for LengthDelimited<R, C>
impl<R> Sink for LengthDelimited<R>
where
R: AsyncWrite,
C: Encoder
{
type SinkItem = <FramedWrite<R, C> as Sink>::SinkItem;
type SinkError = <FramedWrite<R, C> as Sink>::SinkError;
type SinkItem = Bytes;
type SinkError = io::Error;
#[inline]
fn start_send(&mut self, item: Self::SinkItem) -> StartSend<Self::SinkItem, Self::SinkError> {
self.inner.start_send(item)
fn start_send(&mut self, msg: Self::SinkItem) -> StartSend<Self::SinkItem, Self::SinkError> {
// Use the maximum frame length also as a (soft) upper limit
// for the entire write buffer. The actual (hard) limit is thus
// implied to be roughly 2 * MAX_FRAME_SIZE.
if self.write_buffer.len() >= MAX_FRAME_SIZE as usize {
self.poll_complete()?;
if self.write_buffer.len() >= MAX_FRAME_SIZE as usize {
return Ok(AsyncSink::NotReady(msg))
}
}
let len = msg.len() as u16;
if len > MAX_FRAME_SIZE {
return Err(io::Error::new(
io::ErrorKind::InvalidData,
"Maximum frame size exceeded."))
}
let mut uvi_buf = uvi::encode::u16_buffer();
let uvi_len = uvi::encode::u16(len, &mut uvi_buf);
self.write_buffer.reserve(len as usize + uvi_len.len());
self.write_buffer.put(uvi_len);
self.write_buffer.put(msg);
Ok(AsyncSink::Ready)
}
#[inline]
fn poll_complete(&mut self) -> Poll<(), Self::SinkError> {
self.inner.poll_complete()
while !self.write_buffer.is_empty() {
let n = try_ready!(self.inner.poll_write(&self.write_buffer));
if n == 0 {
return Err(io::Error::new(
io::ErrorKind::WriteZero,
"Failed to write buffered frame."))
}
let _ = self.write_buffer.split_to(n);
}
try_ready!(self.inner.poll_flush());
return Ok(Async::Ready(()));
}
#[inline]
fn close(&mut self) -> Poll<(), Self::SinkError> {
self.inner.close()
try_ready!(self.poll_complete());
Ok(self.inner.shutdown()?)
}
}
@ -211,12 +231,11 @@ mod tests {
use futures::{Future, Stream};
use crate::length_delimited::LengthDelimited;
use std::io::{Cursor, ErrorKind};
use unsigned_varint::codec::UviBytes;
#[test]
fn basic_read() {
let data = vec![6, 9, 8, 7, 6, 5, 4];
let framed = LengthDelimited::new(Cursor::new(data), UviBytes::<Vec<_>>::default());
let framed = LengthDelimited::new(Cursor::new(data));
let recved = framed.collect().wait().unwrap();
assert_eq!(recved, vec![vec![9, 8, 7, 6, 5, 4]]);
}
@ -224,7 +243,7 @@ mod tests {
#[test]
fn basic_read_two() {
let data = vec![6, 9, 8, 7, 6, 5, 4, 3, 9, 8, 7];
let framed = LengthDelimited::new(Cursor::new(data), UviBytes::<Vec<_>>::default());
let framed = LengthDelimited::new(Cursor::new(data));
let recved = framed.collect().wait().unwrap();
assert_eq!(recved, vec![vec![9, 8, 7, 6, 5, 4], vec![9, 8, 7]]);
}
@ -236,7 +255,7 @@ mod tests {
let frame = (0..len).map(|n| (n & 0xff) as u8).collect::<Vec<_>>();
let mut data = vec![(len & 0x7f) as u8 | 0x80, (len >> 7) as u8];
data.extend(frame.clone().into_iter());
let framed = LengthDelimited::new(Cursor::new(data), UviBytes::<Vec<_>>::default());
let framed = LengthDelimited::new(Cursor::new(data));
let recved = framed
.into_future()
.map(|(m, _)| m)
@ -250,7 +269,7 @@ mod tests {
fn packet_len_too_long() {
let mut data = vec![0x81, 0x81, 0x1];
data.extend((0..16513).map(|_| 0));
let framed = LengthDelimited::new(Cursor::new(data), UviBytes::<Vec<_>>::default());
let framed = LengthDelimited::new(Cursor::new(data));
let recved = framed
.into_future()
.map(|(m, _)| m)
@ -267,7 +286,7 @@ mod tests {
#[test]
fn empty_frames() {
let data = vec![0, 0, 6, 9, 8, 7, 6, 5, 4, 0, 3, 9, 8, 7];
let framed = LengthDelimited::new(Cursor::new(data), UviBytes::<Vec<_>>::default());
let framed = LengthDelimited::new(Cursor::new(data));
let recved = framed.collect().wait().unwrap();
assert_eq!(
recved,
@ -284,7 +303,7 @@ mod tests {
#[test]
fn unexpected_eof_in_len() {
let data = vec![0x89];
let framed = LengthDelimited::new(Cursor::new(data), UviBytes::<Vec<_>>::default());
let framed = LengthDelimited::new(Cursor::new(data));
let recved = framed.collect().wait();
if let Err(io_err) = recved {
assert_eq!(io_err.kind(), ErrorKind::UnexpectedEof)
@ -296,7 +315,7 @@ mod tests {
#[test]
fn unexpected_eof_in_data() {
let data = vec![5];
let framed = LengthDelimited::new(Cursor::new(data), UviBytes::<Vec<_>>::default());
let framed = LengthDelimited::new(Cursor::new(data));
let recved = framed.collect().wait();
if let Err(io_err) = recved {
assert_eq!(io_err.kind(), ErrorKind::UnexpectedEof)
@ -308,7 +327,7 @@ mod tests {
#[test]
fn unexpected_eof_in_data2() {
let data = vec![5, 9, 8, 7];
let framed = LengthDelimited::new(Cursor::new(data), UviBytes::<Vec<_>>::default());
let framed = LengthDelimited::new(Cursor::new(data));
let recved = framed.collect().wait();
if let Err(io_err) = recved {
assert_eq!(io_err.kind(), ErrorKind::UnexpectedEof)

View File

@ -21,7 +21,7 @@
//! # Multistream-select
//!
//! This crate implements the `multistream-select` protocol, which is the protocol used by libp2p
//! to negotiate which protocol to use with the remote.
//! to negotiate which protocol to use with the remote on a connection or substream.
//!
//! > **Note**: This crate is used by the internals of *libp2p*, and it is not required to
//! > understand it in order to use *libp2p*.
@ -76,6 +76,7 @@ mod protocol;
use futures::prelude::*;
use std::io;
use tokio_io::{AsyncRead, AsyncWrite};
pub use self::dialer_select::{dialer_select_proto, DialerSelectFuture};
pub use self::error::ProtocolChoiceError;
@ -93,9 +94,9 @@ where
}
}
impl<TInner> tokio_io::AsyncRead for Negotiated<TInner>
impl<TInner> AsyncRead for Negotiated<TInner>
where
TInner: tokio_io::AsyncRead
TInner: AsyncRead
{
unsafe fn prepare_uninitialized_buffer(&self, buf: &mut [u8]) -> bool {
self.0.prepare_uninitialized_buffer(buf)
@ -119,9 +120,9 @@ where
}
}
impl<TInner> tokio_io::AsyncWrite for Negotiated<TInner>
impl<TInner> AsyncWrite for Negotiated<TInner>
where
TInner: tokio_io::AsyncWrite
TInner: AsyncWrite
{
fn shutdown(&mut self) -> Poll<(), io::Error> {
self.0.shutdown()

View File

@ -23,10 +23,10 @@
use futures::{prelude::*, sink, stream::StreamFuture};
use crate::protocol::{
DialerToListenerMessage,
Request,
Response,
Listener,
ListenerFuture,
ListenerToDialerMessage
};
use log::{debug, trace};
use std::mem;
@ -126,13 +126,13 @@ where
Err((e, _)) => return Err(ProtocolChoiceError::from(e))
};
match msg {
Some(DialerToListenerMessage::ProtocolsListRequest) => {
Some(Request::ListProtocols) => {
trace!("protocols list response: {:?}", protocols
.into_iter()
.map(|p| p.as_ref().into())
.collect::<Vec<Vec<u8>>>());
let list = protocols.into_iter().collect();
let msg = ListenerToDialerMessage::ProtocolsListResponse { list };
let supported = protocols.into_iter().collect();
let msg = Response::SupportedProtocols { protocols: supported };
let sender = listener.send(msg);
self.inner = ListenerSelectState::Outgoing {
sender,
@ -140,12 +140,12 @@ where
outcome: None
}
}
Some(DialerToListenerMessage::ProtocolRequest { name }) => {
Some(Request::Protocol { name }) => {
let mut outcome = None;
let mut send_back = ListenerToDialerMessage::NotAvailable;
let mut send_back = Response::ProtocolNotAvailable;
for supported in &protocols {
if name.as_ref() == supported.as_ref() {
send_back = ListenerToDialerMessage::ProtocolAck {
send_back = Response::Protocol {
name: supported.clone()
};
outcome = Some(supported);

View File

@ -20,23 +20,25 @@
//! Contains the `Dialer` wrapper, which allows raw communications with a listener.
use bytes::{BufMut, Bytes, BytesMut};
use super::*;
use bytes::{Bytes, BytesMut};
use crate::length_delimited::LengthDelimited;
use crate::protocol::DialerToListenerMessage;
use crate::protocol::ListenerToDialerMessage;
use crate::protocol::MultistreamSelectError;
use crate::protocol::MULTISTREAM_PROTOCOL_WITH_LF;
use crate::protocol::{Request, Response, MultistreamSelectError};
use futures::{prelude::*, sink, Async, StartSend, try_ready};
use std::io;
use tokio_codec::Encoder;
use tokio_io::{AsyncRead, AsyncWrite};
use unsigned_varint::{decode, codec::Uvi};
use std::marker;
use unsigned_varint as uvi;
/// The maximum number of supported protocols that can be processed.
const MAX_PROTOCOLS: usize = 1000;
/// Wraps around a `AsyncRead+AsyncWrite`.
/// Assumes that we're on the dialer's side. Produces and accepts messages.
pub struct Dialer<R, N> {
inner: LengthDelimited<R, MessageEncoder<N>>,
handshake_finished: bool
inner: LengthDelimited<R>,
handshake_finished: bool,
_protocol_name: marker::PhantomData<N>,
}
impl<R, N> Dialer<R, N>
@ -45,15 +47,16 @@ where
N: AsRef<[u8]>
{
pub fn dial(inner: R) -> DialerFuture<R, N> {
let codec = MessageEncoder(std::marker::PhantomData);
let sender = LengthDelimited::new(inner, codec);
let io = LengthDelimited::new(inner);
let mut buf = BytesMut::new();
Header::Multistream10.encode(&mut buf);
DialerFuture {
inner: sender.send(Message::Header)
inner: io.send(buf.freeze()),
_protocol_name: marker::PhantomData,
}
}
/// Grants back the socket. Typically used after a `ProtocolAck` has been received.
#[inline]
pub fn into_inner(self) -> R {
self.inner.into_inner()
}
@ -64,24 +67,22 @@ where
R: AsyncRead + AsyncWrite,
N: AsRef<[u8]>
{
type SinkItem = DialerToListenerMessage<N>;
type SinkItem = Request<N>;
type SinkError = MultistreamSelectError;
#[inline]
fn start_send(&mut self, item: Self::SinkItem) -> StartSend<Self::SinkItem, Self::SinkError> {
match self.inner.start_send(Message::Body(item))? {
AsyncSink::NotReady(Message::Body(item)) => Ok(AsyncSink::NotReady(item)),
AsyncSink::NotReady(Message::Header) => unreachable!(),
AsyncSink::Ready => Ok(AsyncSink::Ready)
fn start_send(&mut self, request: Self::SinkItem) -> StartSend<Self::SinkItem, Self::SinkError> {
let mut msg = BytesMut::new();
request.encode(&mut msg)?;
match self.inner.start_send(msg.freeze())? {
AsyncSink::NotReady(_) => Ok(AsyncSink::NotReady(request)),
AsyncSink::Ready => Ok(AsyncSink::Ready),
}
}
#[inline]
fn poll_complete(&mut self) -> Poll<(), Self::SinkError> {
Ok(self.inner.poll_complete()?)
}
#[inline]
fn close(&mut self) -> Poll<(), Self::SinkError> {
Ok(self.inner.close()?)
}
@ -91,20 +92,20 @@ impl<R, N> Stream for Dialer<R, N>
where
R: AsyncRead + AsyncWrite
{
type Item = ListenerToDialerMessage<Bytes>;
type Item = Response<Bytes>;
type Error = MultistreamSelectError;
fn poll(&mut self) -> Poll<Option<Self::Item>, Self::Error> {
loop {
let mut frame = match self.inner.poll() {
Ok(Async::Ready(Some(frame))) => frame,
let mut msg = match self.inner.poll() {
Ok(Async::Ready(Some(msg))) => msg,
Ok(Async::Ready(None)) => return Ok(Async::Ready(None)),
Ok(Async::NotReady) => return Ok(Async::NotReady),
Err(err) => return Err(err.into()),
};
if !self.handshake_finished {
if frame == MULTISTREAM_PROTOCOL_WITH_LF {
if msg == MSG_MULTISTREAM_1_0 {
self.handshake_finished = true;
continue;
} else {
@ -112,31 +113,31 @@ where
}
}
if frame.get(0) == Some(&b'/') && frame.last() == Some(&b'\n') {
let frame_len = frame.len();
let protocol = frame.split_to(frame_len - 1);
return Ok(Async::Ready(Some(ListenerToDialerMessage::ProtocolAck {
name: protocol
})));
} else if frame == b"na\n"[..] {
return Ok(Async::Ready(Some(ListenerToDialerMessage::NotAvailable)));
if msg.get(0) == Some(&b'/') && msg.last() == Some(&b'\n') {
let len = msg.len();
let name = msg.split_to(len - 1);
return Ok(Async::Ready(Some(
Response::Protocol { name }
)));
} else if msg == MSG_PROTOCOL_NA {
return Ok(Async::Ready(Some(Response::ProtocolNotAvailable)));
} else {
// A varint number of protocols
let (num_protocols, mut remaining) = decode::usize(&frame)?;
if num_protocols > 1000 { // TODO: configurable limit
return Err(MultistreamSelectError::VarintParseError("too many protocols".into()))
let (num_protocols, mut remaining) = uvi::decode::usize(&msg)?;
if num_protocols > MAX_PROTOCOLS { // TODO: configurable limit
return Err(MultistreamSelectError::TooManyProtocols)
}
let mut out = Vec::with_capacity(num_protocols);
let mut protocols = Vec::with_capacity(num_protocols);
for _ in 0 .. num_protocols {
let (len, rem) = decode::usize(remaining)?;
let (len, rem) = uvi::decode::usize(remaining)?;
if len == 0 || len > rem.len() || rem[len - 1] != b'\n' {
return Err(MultistreamSelectError::UnknownMessage)
}
out.push(Bytes::from(&rem[.. len - 1]));
protocols.push(Bytes::from(&rem[.. len - 1]));
remaining = &rem[len ..]
}
return Ok(Async::Ready(Some(
ListenerToDialerMessage::ProtocolsListResponse { list: out },
Response::SupportedProtocols { protocols },
)));
}
}
@ -145,7 +146,8 @@ where
/// Future, returned by `Dialer::new`, which send the handshake and returns the actual `Dialer`.
pub struct DialerFuture<T: AsyncWrite, N: AsRef<[u8]>> {
inner: sink::Send<LengthDelimited<T, MessageEncoder<N>>>
inner: sink::Send<LengthDelimited<T>>,
_protocol_name: marker::PhantomData<N>,
}
impl<T: AsyncWrite, N: AsRef<[u8]>> Future for DialerFuture<T, N> {
@ -154,57 +156,17 @@ impl<T: AsyncWrite, N: AsRef<[u8]>> Future for DialerFuture<T, N> {
fn poll(&mut self) -> Poll<Self::Item, Self::Error> {
let inner = try_ready!(self.inner.poll());
Ok(Async::Ready(Dialer { inner, handshake_finished: false }))
}
}
/// tokio-codec `Encoder` handling `DialerToListenerMessage` values.
struct MessageEncoder<N>(std::marker::PhantomData<N>);
enum Message<N> {
Header,
Body(DialerToListenerMessage<N>)
}
impl<N: AsRef<[u8]>> Encoder for MessageEncoder<N> {
type Item = Message<N>;
type Error = MultistreamSelectError;
fn encode(&mut self, item: Self::Item, dest: &mut BytesMut) -> Result<(), Self::Error> {
match item {
Message::Header => {
Uvi::<usize>::default().encode(MULTISTREAM_PROTOCOL_WITH_LF.len(), dest)?;
dest.reserve(MULTISTREAM_PROTOCOL_WITH_LF.len());
dest.put(MULTISTREAM_PROTOCOL_WITH_LF);
Ok(())
}
Message::Body(DialerToListenerMessage::ProtocolRequest { name }) => {
if !name.as_ref().starts_with(b"/") {
return Err(MultistreamSelectError::WrongProtocolName)
}
let len = name.as_ref().len() + 1; // + 1 for \n
if len > std::u16::MAX as usize {
return Err(io::Error::new(io::ErrorKind::InvalidData, "name too long").into())
}
Uvi::<usize>::default().encode(len, dest)?;
dest.reserve(len);
dest.put(name.as_ref());
dest.put(&b"\n"[..]);
Ok(())
}
Message::Body(DialerToListenerMessage::ProtocolsListRequest) => {
Uvi::<usize>::default().encode(3, dest)?;
dest.reserve(3);
dest.put(&b"ls\n"[..]);
Ok(())
}
}
Ok(Async::Ready(Dialer {
inner,
handshake_finished: false,
_protocol_name: marker::PhantomData,
}))
}
}
#[cfg(test)]
mod tests {
use crate::protocol::{Dialer, DialerToListenerMessage, MultistreamSelectError};
use super::*;
use tokio::runtime::current_thread::Runtime;
use tokio_tcp::{TcpListener, TcpStream};
use futures::Future;
@ -225,13 +187,13 @@ mod tests {
.from_err()
.and_then(move |stream| Dialer::dial(stream))
.and_then(move |dialer| {
let p = b"invalid_name";
dialer.send(DialerToListenerMessage::ProtocolRequest { name: p })
let name = b"invalid_name";
dialer.send(Request::Protocol { name })
});
let mut rt = Runtime::new().unwrap();
match rt.block_on(server.join(client)) {
Err(MultistreamSelectError::WrongProtocolName) => (),
Err(MultistreamSelectError::InvalidProtocolName) => (),
_ => panic!(),
}
}

View File

@ -20,7 +20,7 @@
//! Contains the error structs for the low-level protocol handling.
use std::error;
use std::error::Error;
use std::fmt;
use std::io;
use unsigned_varint::decode;
@ -38,29 +38,25 @@ pub enum MultistreamSelectError {
UnknownMessage,
/// Protocol names must always start with `/`, otherwise this error is returned.
WrongProtocolName,
InvalidProtocolName,
/// Failure to parse variable-length integer.
// TODO: we don't include the actual error, because that would remove Send from the enum
VarintParseError(String),
/// Too many protocols have been returned by the remote.
TooManyProtocols,
}
impl From<io::Error> for MultistreamSelectError {
#[inline]
fn from(err: io::Error) -> MultistreamSelectError {
MultistreamSelectError::IoError(err)
}
}
impl From<decode::Error> for MultistreamSelectError {
#[inline]
fn from(err: decode::Error) -> MultistreamSelectError {
MultistreamSelectError::VarintParseError(err.to_string())
Self::from(io::Error::new(io::ErrorKind::InvalidData, err.to_string()))
}
}
impl error::Error for MultistreamSelectError {
#[inline]
impl Error for MultistreamSelectError {
fn description(&self) -> &str {
match *self {
MultistreamSelectError::IoError(_) => "I/O error",
@ -68,16 +64,15 @@ impl error::Error for MultistreamSelectError {
"the remote doesn't use the same multistream-select protocol as we do"
}
MultistreamSelectError::UnknownMessage => "received an unknown message from the remote",
MultistreamSelectError::WrongProtocolName => {
MultistreamSelectError::InvalidProtocolName => {
"protocol names must always start with `/`, otherwise this error is returned"
}
MultistreamSelectError::VarintParseError(_) => {
"failure to parse variable-length integer"
}
MultistreamSelectError::TooManyProtocols =>
"Too many protocols."
}
}
fn cause(&self) -> Option<&dyn error::Error> {
fn source(&self) -> Option<&(dyn Error + 'static)> {
match *self {
MultistreamSelectError::IoError(ref err) => Some(err),
_ => None,
@ -86,8 +81,7 @@ impl error::Error for MultistreamSelectError {
}
impl fmt::Display for MultistreamSelectError {
#[inline]
fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> {
write!(fmt, "{}", error::Error::description(self))
write!(fmt, "{}", Error::description(self))
}
}

View File

@ -20,23 +20,21 @@
//! Contains the `Listener` wrapper, which allows raw communications with a dialer.
use bytes::{BufMut, Bytes, BytesMut};
use super::*;
use bytes::{Bytes, BytesMut};
use crate::length_delimited::LengthDelimited;
use crate::protocol::DialerToListenerMessage;
use crate::protocol::ListenerToDialerMessage;
use crate::protocol::MultistreamSelectError;
use crate::protocol::MULTISTREAM_PROTOCOL_WITH_LF;
use crate::protocol::{Request, Response, MultistreamSelectError};
use futures::{prelude::*, sink, stream::StreamFuture};
use log::{debug, trace};
use std::{io, mem};
use tokio_codec::Encoder;
use std::{marker, mem};
use tokio_io::{AsyncRead, AsyncWrite};
use unsigned_varint::{encode, codec::Uvi};
/// Wraps around a `AsyncRead+AsyncWrite`. Assumes that we're on the listener's side. Produces and
/// accepts messages.
pub struct Listener<R, N> {
inner: LengthDelimited<R, MessageEncoder<N>>
inner: LengthDelimited<R>,
_protocol_name: marker::PhantomData<N>,
}
impl<R, N> Listener<R, N>
@ -47,16 +45,15 @@ where
/// Takes ownership of a socket and starts the handshake. If the handshake succeeds, the
/// future returns a `Listener`.
pub fn listen(inner: R) -> ListenerFuture<R, N> {
let codec = MessageEncoder(std::marker::PhantomData);
let inner = LengthDelimited::new(inner, codec);
let inner = LengthDelimited::new(inner);
ListenerFuture {
inner: ListenerFutureState::Await { inner: inner.into_future() }
inner: ListenerFutureState::Await { inner: inner.into_future() },
_protocol_name: marker::PhantomData,
}
}
/// Grants back the socket. Typically used after a `ProtocolRequest` has been received and a
/// `ProtocolAck` has been sent back.
#[inline]
pub fn into_inner(self) -> R {
self.inner.into_inner()
}
@ -67,24 +64,22 @@ where
R: AsyncRead + AsyncWrite,
N: AsRef<[u8]>
{
type SinkItem = ListenerToDialerMessage<N>;
type SinkItem = Response<N>;
type SinkError = MultistreamSelectError;
#[inline]
fn start_send(&mut self, item: Self::SinkItem) -> StartSend<Self::SinkItem, Self::SinkError> {
match self.inner.start_send(Message::Body(item))? {
AsyncSink::NotReady(Message::Body(item)) => Ok(AsyncSink::NotReady(item)),
AsyncSink::NotReady(Message::Header) => unreachable!(),
fn start_send(&mut self, response: Self::SinkItem) -> StartSend<Self::SinkItem, Self::SinkError> {
let mut msg = BytesMut::new();
response.encode(&mut msg)?;
match self.inner.start_send(msg.freeze())? {
AsyncSink::NotReady(_) => Ok(AsyncSink::NotReady(response)),
AsyncSink::Ready => Ok(AsyncSink::Ready)
}
}
#[inline]
fn poll_complete(&mut self) -> Poll<(), Self::SinkError> {
Ok(self.inner.poll_complete()?)
}
#[inline]
fn close(&mut self) -> Poll<(), Self::SinkError> {
Ok(self.inner.close()?)
}
@ -94,26 +89,26 @@ impl<R, N> Stream for Listener<R, N>
where
R: AsyncRead + AsyncWrite,
{
type Item = DialerToListenerMessage<Bytes>;
type Item = Request<Bytes>;
type Error = MultistreamSelectError;
fn poll(&mut self) -> Poll<Option<Self::Item>, Self::Error> {
let mut frame = match self.inner.poll() {
Ok(Async::Ready(Some(frame))) => frame,
let mut msg = match self.inner.poll() {
Ok(Async::Ready(Some(msg))) => msg,
Ok(Async::Ready(None)) => return Ok(Async::Ready(None)),
Ok(Async::NotReady) => return Ok(Async::NotReady),
Err(err) => return Err(err.into()),
};
if frame.get(0) == Some(&b'/') && frame.last() == Some(&b'\n') {
let frame_len = frame.len();
let protocol = frame.split_to(frame_len - 1);
if msg.get(0) == Some(&b'/') && msg.last() == Some(&b'\n') {
let len = msg.len();
let name = msg.split_to(len - 1);
Ok(Async::Ready(Some(
DialerToListenerMessage::ProtocolRequest { name: protocol },
Request::Protocol { name },
)))
} else if frame == b"ls\n"[..] {
} else if msg == MSG_LS {
Ok(Async::Ready(Some(
DialerToListenerMessage::ProtocolsListRequest,
Request::ListProtocols,
)))
} else {
Err(MultistreamSelectError::UnknownMessage)
@ -124,16 +119,17 @@ where
/// Future, returned by `Listener::new` which performs the handshake and returns
/// the `Listener` if successful.
pub struct ListenerFuture<T: AsyncRead + AsyncWrite, N: AsRef<[u8]>> {
inner: ListenerFutureState<T, N>
pub struct ListenerFuture<T: AsyncRead + AsyncWrite, N> {
inner: ListenerFutureState<T>,
_protocol_name: marker::PhantomData<N>,
}
enum ListenerFutureState<T: AsyncRead + AsyncWrite, N: AsRef<[u8]>> {
enum ListenerFutureState<T: AsyncRead + AsyncWrite> {
Await {
inner: StreamFuture<LengthDelimited<T, MessageEncoder<N>>>
inner: StreamFuture<LengthDelimited<T>>
},
Reply {
sender: sink::Send<LengthDelimited<T, MessageEncoder<N>>>
sender: sink::Send<LengthDelimited<T>>
},
Undefined
}
@ -155,12 +151,14 @@ impl<T: AsyncRead + AsyncWrite, N: AsRef<[u8]>> Future for ListenerFuture<T, N>
}
Err((e, _)) => return Err(MultistreamSelectError::from(e))
};
if msg.as_ref().map(|b| &b[..]) != Some(MULTISTREAM_PROTOCOL_WITH_LF) {
debug!("failed handshake; received: {:?}", msg);
if msg.as_ref().map(|b| &b[..]) != Some(MSG_MULTISTREAM_1_0) {
debug!("Unexpected message: {:?}", msg);
return Err(MultistreamSelectError::FailedHandshake)
}
trace!("sending back /multistream/<version> to finish the handshake");
let sender = socket.send(Message::Header);
let mut frame = BytesMut::new();
Header::Multistream10.encode(&mut frame);
let sender = socket.send(frame.freeze());
self.inner = ListenerFutureState::Reply { sender }
}
ListenerFutureState::Reply { mut sender } => {
@ -171,70 +169,13 @@ impl<T: AsyncRead + AsyncWrite, N: AsRef<[u8]>> Future for ListenerFuture<T, N>
return Ok(Async::NotReady)
}
};
return Ok(Async::Ready(Listener { inner: listener }))
return Ok(Async::Ready(Listener {
inner: listener,
_protocol_name: marker::PhantomData
}))
}
ListenerFutureState::Undefined => panic!("ListenerFutureState::poll called after completion")
}
}
}
}
/// tokio-codec `Encoder` handling `ListenerToDialerMessage` values.
struct MessageEncoder<N>(std::marker::PhantomData<N>);
enum Message<N> {
Header,
Body(ListenerToDialerMessage<N>)
}
impl<N: AsRef<[u8]>> Encoder for MessageEncoder<N> {
type Item = Message<N>;
type Error = MultistreamSelectError;
fn encode(&mut self, item: Self::Item, dest: &mut BytesMut) -> Result<(), Self::Error> {
match item {
Message::Header => {
Uvi::<usize>::default().encode(MULTISTREAM_PROTOCOL_WITH_LF.len(), dest)?;
dest.reserve(MULTISTREAM_PROTOCOL_WITH_LF.len());
dest.put(MULTISTREAM_PROTOCOL_WITH_LF);
Ok(())
}
Message::Body(ListenerToDialerMessage::ProtocolAck { name }) => {
if !name.as_ref().starts_with(b"/") {
return Err(MultistreamSelectError::WrongProtocolName)
}
let len = name.as_ref().len() + 1; // + 1 for \n
if len > std::u16::MAX as usize {
return Err(io::Error::new(io::ErrorKind::InvalidData, "name too long").into())
}
Uvi::<usize>::default().encode(len, dest)?;
dest.reserve(len);
dest.put(name.as_ref());
dest.put(&b"\n"[..]);
Ok(())
}
Message::Body(ListenerToDialerMessage::ProtocolsListResponse { list }) => {
let mut buf = encode::usize_buffer();
let mut out_msg = Vec::from(encode::usize(list.len(), &mut buf));
for e in &list {
if e.as_ref().len() + 1 > std::u16::MAX as usize {
return Err(io::Error::new(io::ErrorKind::InvalidData, "name too long").into())
}
out_msg.extend(encode::usize(e.as_ref().len() + 1, &mut buf)); // +1 for '\n'
out_msg.extend_from_slice(e.as_ref());
out_msg.push(b'\n')
}
let len = encode::usize(out_msg.len(), &mut buf);
dest.reserve(len.len() + out_msg.len());
dest.put(len);
dest.put(out_msg);
Ok(())
}
Message::Body(ListenerToDialerMessage::NotAvailable) => {
Uvi::<usize>::default().encode(3, dest)?;
dest.reserve(3);
dest.put(&b"na\n"[..]);
Ok(())
ListenerFutureState::Undefined =>
panic!("ListenerFutureState::poll called after completion")
}
}
}
@ -242,12 +183,12 @@ impl<N: AsRef<[u8]>> Encoder for MessageEncoder<N> {
#[cfg(test)]
mod tests {
use super::*;
use tokio::runtime::current_thread::Runtime;
use tokio_tcp::{TcpListener, TcpStream};
use bytes::Bytes;
use futures::Future;
use futures::{Sink, Stream};
use crate::protocol::{Dialer, Listener, ListenerToDialerMessage, MultistreamSelectError};
#[test]
fn wrong_proto_name() {
@ -260,8 +201,8 @@ mod tests {
.map_err(|(e, _)| e.into())
.and_then(move |(connec, _)| Listener::listen(connec.unwrap()))
.and_then(|listener| {
let proto_name = Bytes::from("invalid-proto");
listener.send(ListenerToDialerMessage::ProtocolAck { name: proto_name })
let name = Bytes::from("invalid-proto");
listener.send(Response::Protocol { name })
});
let client = TcpStream::connect(&listener_addr)
@ -270,7 +211,7 @@ mod tests {
let mut rt = Runtime::new().unwrap();
match rt.block_on(server.join(client)) {
Err(MultistreamSelectError::WrongProtocolName) => (),
Err(MultistreamSelectError::InvalidProtocolName) => (),
_ => panic!(),
}
}

View File

@ -20,47 +20,125 @@
//! Contains lower-level structs to handle the multistream protocol.
const MSG_MULTISTREAM_1_0: &[u8] = b"/multistream/1.0.0\n";
const MSG_PROTOCOL_NA: &[u8] = b"na\n";
const MSG_LS: &[u8] = b"ls\n";
mod dialer;
mod error;
mod listener;
const MULTISTREAM_PROTOCOL_WITH_LF: &[u8] = b"/multistream/1.0.0\n";
pub use self::dialer::{Dialer, DialerFuture};
pub use self::error::MultistreamSelectError;
pub use self::listener::{Listener, ListenerFuture};
use bytes::{BytesMut, BufMut};
use unsigned_varint as uvi;
pub enum Header {
Multistream10
}
impl Header {
fn encode(&self, dest: &mut BytesMut) {
match self {
Header::Multistream10 => {
dest.reserve(MSG_MULTISTREAM_1_0.len());
dest.put(MSG_MULTISTREAM_1_0);
}
}
}
}
/// Message sent from the dialer to the listener.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum DialerToListenerMessage<N> {
pub enum Request<N> {
/// The dialer wants us to use a protocol.
///
/// If this is accepted (by receiving back a `ProtocolAck`), then we immediately start
/// communicating in the new protocol.
ProtocolRequest {
Protocol {
/// Name of the protocol.
name: N
},
/// The dialer requested the list of protocols that the listener supports.
ProtocolsListRequest,
ListProtocols,
}
impl<N: AsRef<[u8]>> Request<N> {
fn encode(&self, dest: &mut BytesMut) -> Result<(), MultistreamSelectError> {
match self {
Request::Protocol { name } => {
if !name.as_ref().starts_with(b"/") {
return Err(MultistreamSelectError::InvalidProtocolName)
}
let len = name.as_ref().len() + 1; // + 1 for \n
dest.reserve(len);
dest.put(name.as_ref());
dest.put(&b"\n"[..]);
Ok(())
}
Request::ListProtocols => {
dest.reserve(MSG_LS.len());
dest.put(MSG_LS);
Ok(())
}
}
}
}
/// Message sent from the listener to the dialer.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ListenerToDialerMessage<N> {
pub enum Response<N> {
/// The protocol requested by the dialer is accepted. The socket immediately starts using the
/// new protocol.
ProtocolAck { name: N },
Protocol { name: N },
/// The protocol requested by the dialer is not supported or available.
NotAvailable,
ProtocolNotAvailable,
/// Response to the request for the list of protocols.
ProtocolsListResponse {
SupportedProtocols {
/// The list of protocols.
// TODO: use some sort of iterator
list: Vec<N>,
protocols: Vec<N>,
},
}
impl<N: AsRef<[u8]>> Response<N> {
fn encode(&self, dest: &mut BytesMut) -> Result<(), MultistreamSelectError> {
match self {
Response::Protocol { name } => {
if !name.as_ref().starts_with(b"/") {
return Err(MultistreamSelectError::InvalidProtocolName)
}
let len = name.as_ref().len() + 1; // + 1 for \n
dest.reserve(len);
dest.put(name.as_ref());
dest.put(&b"\n"[..]);
Ok(())
}
Response::SupportedProtocols { protocols } => {
let mut buf = uvi::encode::usize_buffer();
let mut out_msg = Vec::from(uvi::encode::usize(protocols.len(), &mut buf));
for p in protocols {
out_msg.extend(uvi::encode::usize(p.as_ref().len() + 1, &mut buf)); // +1 for '\n'
out_msg.extend_from_slice(p.as_ref());
out_msg.push(b'\n')
}
dest.reserve(out_msg.len());
dest.put(out_msg);
Ok(())
}
Response::ProtocolNotAvailable => {
dest.reserve(MSG_PROTOCOL_NA.len());
dest.put(MSG_PROTOCOL_NA);
Ok(())
}
}
}
}

View File

@ -24,7 +24,7 @@
use crate::ProtocolChoiceError;
use crate::dialer_select::{dialer_select_proto_parallel, dialer_select_proto_serial};
use crate::protocol::{Dialer, DialerToListenerMessage, Listener, ListenerToDialerMessage};
use crate::protocol::{Dialer, Request, Listener, Response};
use crate::{dialer_select_proto, listener_select_proto};
use futures::prelude::*;
use tokio::runtime::current_thread::Runtime;
@ -56,23 +56,23 @@ fn negotiate_with_self_succeeds() {
.and_then(|l| l.into_future().map_err(|(e, _)| e))
.and_then(|(msg, rest)| {
let proto = match msg {
Some(DialerToListenerMessage::ProtocolRequest { name }) => name,
Some(Request::Protocol { name }) => name,
_ => panic!(),
};
rest.send(ListenerToDialerMessage::ProtocolAck { name: proto })
rest.send(Response::Protocol { name: proto })
});
let client = TcpStream::connect(&listener_addr)
.from_err()
.and_then(move |stream| Dialer::dial(stream))
.and_then(move |dialer| {
let p = b"/hello/1.0.0";
dialer.send(DialerToListenerMessage::ProtocolRequest { name: p })
let name = b"/hello/1.0.0";
dialer.send(Request::Protocol { name })
})
.and_then(move |dialer| dialer.into_future().map_err(|(e, _)| e))
.and_then(move |(msg, _)| {
let proto = match msg {
Some(ListenerToDialerMessage::ProtocolAck { name }) => name,
Some(Response::Protocol { name }) => name,
_ => panic!(),
};
assert_eq!(proto, "/hello/1.0.0");