mirror of
https://github.com/fluencelabs/rust-libp2p
synced 2025-06-01 20:21:21 +00:00
Remove tokio-codec dependency from multistream-select. (#1203)
* Remove tokio-codec dependency from multistream-select. In preparation for the eventual switch from tokio to std futures. Includes some initial refactoring in preparation for further work in the context of https://github.com/libp2p/rust-libp2p/issues/659. * Reduce default buffer sizes. * Allow more than one frame to be buffered for sending. * Doc tweaks. * Remove superfluous (duplicated) Message types.
This commit is contained in:
parent
bcc7c4d349
commit
2fd941122a
@ -14,9 +14,8 @@ bytes = "0.4"
|
||||
futures = { version = "0.1" }
|
||||
log = "0.4"
|
||||
smallvec = "0.6"
|
||||
tokio-codec = "0.1"
|
||||
tokio-io = "0.1"
|
||||
unsigned-varint = { version = "0.2.1", features = ["codec"] }
|
||||
unsigned-varint = { version = "0.2.2" }
|
||||
|
||||
[dev-dependencies]
|
||||
tokio = "0.1"
|
||||
|
@ -22,19 +22,14 @@
|
||||
//! `multistream-select` for the dialer.
|
||||
|
||||
use futures::{future::Either, prelude::*, stream::StreamFuture};
|
||||
use crate::protocol::{
|
||||
Dialer,
|
||||
DialerFuture,
|
||||
DialerToListenerMessage,
|
||||
ListenerToDialerMessage
|
||||
};
|
||||
use crate::protocol::{Dialer, DialerFuture, Request, Response};
|
||||
use log::trace;
|
||||
use std::mem;
|
||||
use tokio_io::{AsyncRead, AsyncWrite};
|
||||
use crate::{Negotiated, ProtocolChoiceError};
|
||||
|
||||
/// Future, returned by `dialer_select_proto`, which selects a protocol and dialer
|
||||
/// either sequentially of by considering all protocols in parallel.
|
||||
/// either sequentially or by considering all protocols in parallel.
|
||||
pub type DialerSelectFuture<R, I> = Either<DialerSelectSeq<R, I>, DialerSelectPar<R, I>>;
|
||||
|
||||
/// Helps selecting a protocol amongst the ones supported.
|
||||
@ -75,7 +70,10 @@ where
|
||||
{
|
||||
let protocols = protocols.into_iter();
|
||||
DialerSelectSeq {
|
||||
inner: DialerSelectSeqState::AwaitDialer { dialer_fut: Dialer::dial(inner), protocols }
|
||||
inner: DialerSelectSeqState::AwaitDialer {
|
||||
dialer_fut: Dialer::dial(inner),
|
||||
protocols
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -148,9 +146,7 @@ where
|
||||
}
|
||||
DialerSelectSeqState::NextProtocol { mut dialer, protocols, proto_name } => {
|
||||
trace!("sending {:?}", proto_name.as_ref());
|
||||
let req = DialerToListenerMessage::ProtocolRequest {
|
||||
name: proto_name.clone()
|
||||
};
|
||||
let req = Request::Protocol { name: proto_name.clone() };
|
||||
match dialer.start_send(req)? {
|
||||
AsyncSink::Ready => {
|
||||
self.inner = DialerSelectSeqState::FlushProtocol {
|
||||
@ -204,12 +200,12 @@ where
|
||||
};
|
||||
trace!("received {:?}", m);
|
||||
match m.ok_or(ProtocolChoiceError::UnexpectedMessage)? {
|
||||
ListenerToDialerMessage::ProtocolAck { ref name }
|
||||
Response::Protocol { ref name }
|
||||
if name.as_ref() == proto_name.as_ref() =>
|
||||
{
|
||||
return Ok(Async::Ready((proto_name, Negotiated(r.into_inner()))))
|
||||
}
|
||||
ListenerToDialerMessage::NotAvailable => {
|
||||
Response::ProtocolNotAvailable => {
|
||||
let proto_name = protocols.next()
|
||||
.ok_or(ProtocolChoiceError::NoProtocolFound)?;
|
||||
self.inner = DialerSelectSeqState::NextProtocol {
|
||||
@ -244,9 +240,8 @@ where
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
/// Future, returned by `dialer_select_proto_parallel`, which selects a protocol and dialer in
|
||||
/// parellel, by first requesting the liste of protocols supported by the remote endpoint and
|
||||
/// parallel, by first requesting the list of protocols supported by the remote endpoint and
|
||||
/// then selecting the most appropriate one by applying a match predicate to the result.
|
||||
pub struct DialerSelectPar<R, I>
|
||||
where
|
||||
@ -319,7 +314,7 @@ where
|
||||
}
|
||||
DialerSelectParState::ProtocolList { mut dialer, protocols } => {
|
||||
trace!("requesting protocols list");
|
||||
match dialer.start_send(DialerToListenerMessage::ProtocolsListRequest)? {
|
||||
match dialer.start_send(Request::ListProtocols)? {
|
||||
AsyncSink::Ready => {
|
||||
self.inner = DialerSelectParState::FlushListRequest {
|
||||
dialer,
|
||||
@ -359,15 +354,15 @@ where
|
||||
Err((e, _)) => return Err(ProtocolChoiceError::from(e))
|
||||
};
|
||||
trace!("protocols list response: {:?}", resp);
|
||||
let list =
|
||||
if let Some(ListenerToDialerMessage::ProtocolsListResponse { list }) = resp {
|
||||
list
|
||||
let supported =
|
||||
if let Some(Response::SupportedProtocols { protocols }) = resp {
|
||||
protocols
|
||||
} else {
|
||||
return Err(ProtocolChoiceError::UnexpectedMessage)
|
||||
};
|
||||
let mut found = None;
|
||||
for local_name in protocols {
|
||||
for remote_name in &list {
|
||||
for remote_name in &supported {
|
||||
if remote_name.as_ref() == local_name.as_ref() {
|
||||
found = Some(local_name);
|
||||
break;
|
||||
@ -381,10 +376,8 @@ where
|
||||
self.inner = DialerSelectParState::Protocol { dialer, proto_name }
|
||||
}
|
||||
DialerSelectParState::Protocol { mut dialer, proto_name } => {
|
||||
trace!("requesting protocol: {:?}", proto_name.as_ref());
|
||||
let req = DialerToListenerMessage::ProtocolRequest {
|
||||
name: proto_name.clone()
|
||||
};
|
||||
trace!("Requesting protocol: {:?}", proto_name.as_ref());
|
||||
let req = Request::Protocol { name: proto_name.clone() };
|
||||
match dialer.start_send(req)? {
|
||||
AsyncSink::Ready => {
|
||||
self.inner = DialerSelectParState::FlushProtocol { dialer, proto_name }
|
||||
@ -420,7 +413,7 @@ where
|
||||
};
|
||||
trace!("received {:?}", resp);
|
||||
match resp {
|
||||
Some(ListenerToDialerMessage::ProtocolAck { ref name })
|
||||
Some(Response::Protocol { ref name })
|
||||
if name.as_ref() == proto_name.as_ref() =>
|
||||
{
|
||||
return Ok(Async::Ready((proto_name, Negotiated(dialer.into_inner()))))
|
||||
|
@ -21,9 +21,8 @@
|
||||
//! Main `ProtocolChoiceError` error.
|
||||
|
||||
use crate::protocol::MultistreamSelectError;
|
||||
use std::error;
|
||||
use std::fmt;
|
||||
use std::io::Error as IoError;
|
||||
use std::error::Error;
|
||||
use std::{fmt, io};
|
||||
|
||||
/// Error that can happen when negotiating a protocol with the remote.
|
||||
#[derive(Debug)]
|
||||
@ -39,21 +38,18 @@ pub enum ProtocolChoiceError {
|
||||
}
|
||||
|
||||
impl From<MultistreamSelectError> for ProtocolChoiceError {
|
||||
#[inline]
|
||||
fn from(err: MultistreamSelectError) -> ProtocolChoiceError {
|
||||
ProtocolChoiceError::MultistreamSelectError(err)
|
||||
}
|
||||
}
|
||||
|
||||
impl From<IoError> for ProtocolChoiceError {
|
||||
#[inline]
|
||||
fn from(err: IoError) -> ProtocolChoiceError {
|
||||
impl From<io::Error> for ProtocolChoiceError {
|
||||
fn from(err: io::Error) -> ProtocolChoiceError {
|
||||
MultistreamSelectError::from(err).into()
|
||||
}
|
||||
}
|
||||
|
||||
impl error::Error for ProtocolChoiceError {
|
||||
#[inline]
|
||||
impl Error for ProtocolChoiceError {
|
||||
fn description(&self) -> &str {
|
||||
match *self {
|
||||
ProtocolChoiceError::MultistreamSelectError(_) => "error in the protocol",
|
||||
@ -66,7 +62,7 @@ impl error::Error for ProtocolChoiceError {
|
||||
}
|
||||
}
|
||||
|
||||
fn cause(&self) -> Option<&dyn error::Error> {
|
||||
fn source(&self) -> Option<&(dyn Error + 'static)> {
|
||||
match *self {
|
||||
ProtocolChoiceError::MultistreamSelectError(ref err) => Some(err),
|
||||
_ => None,
|
||||
@ -75,8 +71,7 @@ impl error::Error for ProtocolChoiceError {
|
||||
}
|
||||
|
||||
impl fmt::Display for ProtocolChoiceError {
|
||||
#[inline]
|
||||
fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> {
|
||||
write!(fmt, "{}", error::Error::description(self))
|
||||
write!(fmt, "{}", Error::description(self))
|
||||
}
|
||||
}
|
||||
|
@ -18,78 +18,81 @@
|
||||
// FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
|
||||
// DEALINGS IN THE SOFTWARE.
|
||||
|
||||
use bytes::Bytes;
|
||||
use futures::{Async, Poll, Sink, StartSend, Stream};
|
||||
use smallvec::SmallVec;
|
||||
use bytes::{Bytes, BytesMut, BufMut};
|
||||
use futures::{try_ready, Async, Poll, Sink, StartSend, Stream, AsyncSink};
|
||||
use std::{io, u16};
|
||||
use tokio_codec::{Encoder, FramedWrite};
|
||||
use tokio_io::{AsyncRead, AsyncWrite};
|
||||
use unsigned_varint::decode;
|
||||
use unsigned_varint as uvi;
|
||||
|
||||
/// `Stream` and `Sink` wrapping some `AsyncRead + AsyncWrite` object to read
|
||||
const MAX_LEN_BYTES: u16 = 2;
|
||||
const MAX_FRAME_SIZE: u16 = (1 << (MAX_LEN_BYTES * 8 - MAX_LEN_BYTES)) - 1;
|
||||
const DEFAULT_BUFFER_SIZE: usize = 64;
|
||||
|
||||
/// `Stream` and `Sink` wrapping some `AsyncRead + AsyncWrite` resource to read
|
||||
/// and write unsigned-varint prefixed frames.
|
||||
///
|
||||
/// We purposely only support a frame length of under 64kiB. Frames mostly consist
|
||||
/// in a short protocol name, which is highly unlikely to be more than 64kiB long.
|
||||
pub struct LengthDelimited<R, C> {
|
||||
// The inner socket where data is pulled from.
|
||||
inner: FramedWrite<R, C>,
|
||||
// Intermediary buffer where we put either the length of the next frame of data, or the frame
|
||||
// of data itself before it is returned.
|
||||
// Must always contain enough space to read data from `inner`.
|
||||
internal_buffer: SmallVec<[u8; 64]>,
|
||||
// Number of bytes within `internal_buffer` that contain valid data.
|
||||
internal_buffer_pos: usize,
|
||||
// State of the decoder.
|
||||
state: State
|
||||
/// We purposely only support a frame sizes up to 16KiB (2 bytes unsigned varint
|
||||
/// frame length). Frames mostly consist in a short protocol name, which is highly
|
||||
/// unlikely to be more than 16KiB long.
|
||||
pub struct LengthDelimited<R> {
|
||||
/// The inner I/O resource.
|
||||
inner: R,
|
||||
/// Read buffer for a single incoming unsigned-varint length-delimited frame.
|
||||
read_buffer: BytesMut,
|
||||
/// Write buffer for outgoing unsigned-varint length-delimited frames.
|
||||
write_buffer: BytesMut,
|
||||
/// The current read state, alternating between reading a frame
|
||||
/// length and reading a frame payload.
|
||||
read_state: ReadState,
|
||||
}
|
||||
|
||||
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
|
||||
enum State {
|
||||
// We are currently reading the length of the next frame of data.
|
||||
ReadingLength,
|
||||
// We are currently reading the frame of data itself.
|
||||
ReadingData { frame_len: u16 },
|
||||
enum ReadState {
|
||||
/// We are currently reading the length of the next frame of data.
|
||||
ReadLength { buf: [u8; MAX_LEN_BYTES as usize], pos: usize },
|
||||
/// We are currently reading the frame of data itself.
|
||||
ReadData { len: u16, pos: usize },
|
||||
}
|
||||
|
||||
impl<R, C> LengthDelimited<R, C>
|
||||
where
|
||||
R: AsyncWrite,
|
||||
C: Encoder
|
||||
{
|
||||
pub fn new(inner: R, codec: C) -> LengthDelimited<R, C> {
|
||||
impl Default for ReadState {
|
||||
fn default() -> Self {
|
||||
ReadState::ReadLength {
|
||||
buf: [0; MAX_LEN_BYTES as usize],
|
||||
pos: 0
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<R> LengthDelimited<R> {
|
||||
/// Creates a new I/O resource for reading and writing unsigned-varint
|
||||
/// length delimited frames.
|
||||
pub fn new(inner: R) -> LengthDelimited<R> {
|
||||
LengthDelimited {
|
||||
inner: FramedWrite::new(inner, codec),
|
||||
internal_buffer: {
|
||||
let mut v = SmallVec::new();
|
||||
v.push(0);
|
||||
v
|
||||
},
|
||||
internal_buffer_pos: 0,
|
||||
state: State::ReadingLength
|
||||
inner,
|
||||
read_state: ReadState::default(),
|
||||
read_buffer: BytesMut::with_capacity(DEFAULT_BUFFER_SIZE),
|
||||
write_buffer: BytesMut::with_capacity(DEFAULT_BUFFER_SIZE + MAX_LEN_BYTES as usize),
|
||||
}
|
||||
}
|
||||
|
||||
/// Destroys the `LengthDelimited` and returns the underlying socket.
|
||||
///
|
||||
/// Contrary to its equivalent `tokio_io::codec::length_delimited::FramedRead`, this method is
|
||||
/// guaranteed not to skip any data from the socket.
|
||||
/// This method is guaranteed not to skip any data from the socket.
|
||||
///
|
||||
/// # Panic
|
||||
///
|
||||
/// Will panic if called while there is data inside the buffer. **This can only happen if
|
||||
/// you call `poll()` manually**. Using this struct as it is intended to be used (i.e. through
|
||||
/// the modifiers provided by the `futures` crate) will always leave the object in a state in
|
||||
/// which `into_inner()` will not panic.
|
||||
#[inline]
|
||||
/// Will panic if called while there is data inside the read or write buffer.
|
||||
/// **This can only happen if you call `poll()` manually**. Using this struct
|
||||
/// as it is intended to be used (i.e. through the high-level `futures` API)
|
||||
/// will always leave the object in a state in which `into_inner()` will not panic.
|
||||
pub fn into_inner(self) -> R {
|
||||
assert_eq!(self.state, State::ReadingLength);
|
||||
assert_eq!(self.internal_buffer_pos, 0);
|
||||
self.inner.into_inner()
|
||||
assert!(self.write_buffer.is_empty());
|
||||
assert!(self.read_buffer.is_empty());
|
||||
self.inner
|
||||
}
|
||||
}
|
||||
|
||||
impl<R, C> Stream for LengthDelimited<R, C>
|
||||
impl<R> Stream for LengthDelimited<R>
|
||||
where
|
||||
R: AsyncRead
|
||||
{
|
||||
@ -98,16 +101,11 @@ where
|
||||
|
||||
fn poll(&mut self) -> Poll<Option<Self::Item>, Self::Error> {
|
||||
loop {
|
||||
debug_assert!(!self.internal_buffer.is_empty());
|
||||
debug_assert!(self.internal_buffer_pos < self.internal_buffer.len());
|
||||
|
||||
match self.state {
|
||||
State::ReadingLength => {
|
||||
let slice = &mut self.internal_buffer[self.internal_buffer_pos..];
|
||||
match self.inner.get_mut().read(slice) {
|
||||
match &mut self.read_state {
|
||||
ReadState::ReadLength { buf, pos } => {
|
||||
match self.inner.read(&mut buf[*pos .. *pos + 1]) {
|
||||
Ok(0) => {
|
||||
// EOF
|
||||
if self.internal_buffer_pos == 0 {
|
||||
if *pos == 0 {
|
||||
return Ok(Async::Ready(None));
|
||||
} else {
|
||||
return Err(io::ErrorKind::UnexpectedEof.into());
|
||||
@ -115,7 +113,7 @@ where
|
||||
}
|
||||
Ok(n) => {
|
||||
debug_assert_eq!(n, 1);
|
||||
self.internal_buffer_pos += n;
|
||||
*pos += n;
|
||||
}
|
||||
Err(ref err) if err.kind() == io::ErrorKind::WouldBlock => {
|
||||
return Ok(Async::NotReady);
|
||||
@ -125,56 +123,45 @@ where
|
||||
}
|
||||
};
|
||||
|
||||
debug_assert_eq!(self.internal_buffer.len(), self.internal_buffer_pos);
|
||||
|
||||
if (*self.internal_buffer.last().unwrap_or(&0) & 0x80) == 0 {
|
||||
// End of length prefix. Most of the time we will switch to reading data,
|
||||
// but we need to handle a few corner cases first.
|
||||
let (frame_len, _) = decode::u16(&self.internal_buffer).map_err(|e| {
|
||||
if (buf[*pos - 1] & 0x80) == 0 {
|
||||
// MSB is not set, indicating the end of the length prefix.
|
||||
let (len, _) = uvi::decode::u16(buf).map_err(|e| {
|
||||
log::debug!("invalid length prefix: {}", e);
|
||||
io::Error::new(io::ErrorKind::InvalidData, "invalid length prefix")
|
||||
})?;
|
||||
|
||||
if frame_len >= 1 {
|
||||
self.state = State::ReadingData { frame_len };
|
||||
self.internal_buffer.clear();
|
||||
self.internal_buffer.reserve(frame_len as usize);
|
||||
self.internal_buffer.extend((0..frame_len).map(|_| 0));
|
||||
self.internal_buffer_pos = 0;
|
||||
if len >= 1 {
|
||||
self.read_state = ReadState::ReadData { len, pos: 0 };
|
||||
self.read_buffer.resize(len as usize, 0);
|
||||
} else {
|
||||
debug_assert_eq!(frame_len, 0);
|
||||
self.state = State::ReadingLength;
|
||||
self.internal_buffer.clear();
|
||||
self.internal_buffer.push(0);
|
||||
self.internal_buffer_pos = 0;
|
||||
return Ok(Async::Ready(Some(From::from(&[][..]))));
|
||||
debug_assert_eq!(len, 0);
|
||||
self.read_state = ReadState::default();
|
||||
return Ok(Async::Ready(Some(Bytes::new())));
|
||||
}
|
||||
} else if self.internal_buffer_pos >= 2 {
|
||||
// Length prefix is too long. See module doc for info about max frame len.
|
||||
return Err(io::Error::new(io::ErrorKind::InvalidData, "frame length too long"));
|
||||
} else {
|
||||
// Prepare for next read.
|
||||
self.internal_buffer.push(0);
|
||||
} else if *pos == MAX_LEN_BYTES as usize {
|
||||
// MSB signals more length bytes but we have already read the maximum.
|
||||
// See the module documentation about the max frame len.
|
||||
return Err(io::Error::new(
|
||||
io::ErrorKind::InvalidData,
|
||||
"Maximum frame length exceeded"));
|
||||
}
|
||||
}
|
||||
State::ReadingData { frame_len } => {
|
||||
let slice = &mut self.internal_buffer[self.internal_buffer_pos..];
|
||||
match self.inner.get_mut().read(slice) {
|
||||
ReadState::ReadData { len, pos } => {
|
||||
match self.inner.read(&mut self.read_buffer[*pos..]) {
|
||||
Ok(0) => return Err(io::ErrorKind::UnexpectedEof.into()),
|
||||
Ok(n) => self.internal_buffer_pos += n,
|
||||
Err(ref err) if err.kind() == io::ErrorKind::WouldBlock => {
|
||||
return Ok(Async::NotReady)
|
||||
}
|
||||
Err(err) => return Err(err)
|
||||
Ok(n) => *pos += n,
|
||||
Err(err) =>
|
||||
if err.kind() == io::ErrorKind::WouldBlock {
|
||||
return Ok(Async::NotReady)
|
||||
} else {
|
||||
return Err(err)
|
||||
}
|
||||
};
|
||||
if self.internal_buffer_pos >= frame_len as usize {
|
||||
// Finished reading the frame of data.
|
||||
self.state = State::ReadingLength;
|
||||
let out_data = From::from(&self.internal_buffer[..]);
|
||||
self.internal_buffer.clear();
|
||||
self.internal_buffer.push(0);
|
||||
self.internal_buffer_pos = 0;
|
||||
return Ok(Async::Ready(Some(out_data)));
|
||||
if *pos == *len as usize {
|
||||
// Finished reading the frame.
|
||||
let frame = self.read_buffer.split_off(0).freeze();
|
||||
self.read_state = ReadState::default();
|
||||
return Ok(Async::Ready(Some(frame)));
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -182,27 +169,60 @@ where
|
||||
}
|
||||
}
|
||||
|
||||
impl<R, C> Sink for LengthDelimited<R, C>
|
||||
impl<R> Sink for LengthDelimited<R>
|
||||
where
|
||||
R: AsyncWrite,
|
||||
C: Encoder
|
||||
{
|
||||
type SinkItem = <FramedWrite<R, C> as Sink>::SinkItem;
|
||||
type SinkError = <FramedWrite<R, C> as Sink>::SinkError;
|
||||
type SinkItem = Bytes;
|
||||
type SinkError = io::Error;
|
||||
|
||||
#[inline]
|
||||
fn start_send(&mut self, item: Self::SinkItem) -> StartSend<Self::SinkItem, Self::SinkError> {
|
||||
self.inner.start_send(item)
|
||||
fn start_send(&mut self, msg: Self::SinkItem) -> StartSend<Self::SinkItem, Self::SinkError> {
|
||||
// Use the maximum frame length also as a (soft) upper limit
|
||||
// for the entire write buffer. The actual (hard) limit is thus
|
||||
// implied to be roughly 2 * MAX_FRAME_SIZE.
|
||||
if self.write_buffer.len() >= MAX_FRAME_SIZE as usize {
|
||||
self.poll_complete()?;
|
||||
if self.write_buffer.len() >= MAX_FRAME_SIZE as usize {
|
||||
return Ok(AsyncSink::NotReady(msg))
|
||||
}
|
||||
}
|
||||
|
||||
let len = msg.len() as u16;
|
||||
if len > MAX_FRAME_SIZE {
|
||||
return Err(io::Error::new(
|
||||
io::ErrorKind::InvalidData,
|
||||
"Maximum frame size exceeded."))
|
||||
}
|
||||
|
||||
let mut uvi_buf = uvi::encode::u16_buffer();
|
||||
let uvi_len = uvi::encode::u16(len, &mut uvi_buf);
|
||||
self.write_buffer.reserve(len as usize + uvi_len.len());
|
||||
self.write_buffer.put(uvi_len);
|
||||
self.write_buffer.put(msg);
|
||||
|
||||
Ok(AsyncSink::Ready)
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn poll_complete(&mut self) -> Poll<(), Self::SinkError> {
|
||||
self.inner.poll_complete()
|
||||
while !self.write_buffer.is_empty() {
|
||||
let n = try_ready!(self.inner.poll_write(&self.write_buffer));
|
||||
|
||||
if n == 0 {
|
||||
return Err(io::Error::new(
|
||||
io::ErrorKind::WriteZero,
|
||||
"Failed to write buffered frame."))
|
||||
}
|
||||
|
||||
let _ = self.write_buffer.split_to(n);
|
||||
}
|
||||
|
||||
try_ready!(self.inner.poll_flush());
|
||||
return Ok(Async::Ready(()));
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn close(&mut self) -> Poll<(), Self::SinkError> {
|
||||
self.inner.close()
|
||||
try_ready!(self.poll_complete());
|
||||
Ok(self.inner.shutdown()?)
|
||||
}
|
||||
}
|
||||
|
||||
@ -211,12 +231,11 @@ mod tests {
|
||||
use futures::{Future, Stream};
|
||||
use crate::length_delimited::LengthDelimited;
|
||||
use std::io::{Cursor, ErrorKind};
|
||||
use unsigned_varint::codec::UviBytes;
|
||||
|
||||
#[test]
|
||||
fn basic_read() {
|
||||
let data = vec![6, 9, 8, 7, 6, 5, 4];
|
||||
let framed = LengthDelimited::new(Cursor::new(data), UviBytes::<Vec<_>>::default());
|
||||
let framed = LengthDelimited::new(Cursor::new(data));
|
||||
let recved = framed.collect().wait().unwrap();
|
||||
assert_eq!(recved, vec![vec![9, 8, 7, 6, 5, 4]]);
|
||||
}
|
||||
@ -224,7 +243,7 @@ mod tests {
|
||||
#[test]
|
||||
fn basic_read_two() {
|
||||
let data = vec![6, 9, 8, 7, 6, 5, 4, 3, 9, 8, 7];
|
||||
let framed = LengthDelimited::new(Cursor::new(data), UviBytes::<Vec<_>>::default());
|
||||
let framed = LengthDelimited::new(Cursor::new(data));
|
||||
let recved = framed.collect().wait().unwrap();
|
||||
assert_eq!(recved, vec![vec![9, 8, 7, 6, 5, 4], vec![9, 8, 7]]);
|
||||
}
|
||||
@ -236,7 +255,7 @@ mod tests {
|
||||
let frame = (0..len).map(|n| (n & 0xff) as u8).collect::<Vec<_>>();
|
||||
let mut data = vec![(len & 0x7f) as u8 | 0x80, (len >> 7) as u8];
|
||||
data.extend(frame.clone().into_iter());
|
||||
let framed = LengthDelimited::new(Cursor::new(data), UviBytes::<Vec<_>>::default());
|
||||
let framed = LengthDelimited::new(Cursor::new(data));
|
||||
let recved = framed
|
||||
.into_future()
|
||||
.map(|(m, _)| m)
|
||||
@ -250,7 +269,7 @@ mod tests {
|
||||
fn packet_len_too_long() {
|
||||
let mut data = vec![0x81, 0x81, 0x1];
|
||||
data.extend((0..16513).map(|_| 0));
|
||||
let framed = LengthDelimited::new(Cursor::new(data), UviBytes::<Vec<_>>::default());
|
||||
let framed = LengthDelimited::new(Cursor::new(data));
|
||||
let recved = framed
|
||||
.into_future()
|
||||
.map(|(m, _)| m)
|
||||
@ -267,7 +286,7 @@ mod tests {
|
||||
#[test]
|
||||
fn empty_frames() {
|
||||
let data = vec![0, 0, 6, 9, 8, 7, 6, 5, 4, 0, 3, 9, 8, 7];
|
||||
let framed = LengthDelimited::new(Cursor::new(data), UviBytes::<Vec<_>>::default());
|
||||
let framed = LengthDelimited::new(Cursor::new(data));
|
||||
let recved = framed.collect().wait().unwrap();
|
||||
assert_eq!(
|
||||
recved,
|
||||
@ -284,7 +303,7 @@ mod tests {
|
||||
#[test]
|
||||
fn unexpected_eof_in_len() {
|
||||
let data = vec![0x89];
|
||||
let framed = LengthDelimited::new(Cursor::new(data), UviBytes::<Vec<_>>::default());
|
||||
let framed = LengthDelimited::new(Cursor::new(data));
|
||||
let recved = framed.collect().wait();
|
||||
if let Err(io_err) = recved {
|
||||
assert_eq!(io_err.kind(), ErrorKind::UnexpectedEof)
|
||||
@ -296,7 +315,7 @@ mod tests {
|
||||
#[test]
|
||||
fn unexpected_eof_in_data() {
|
||||
let data = vec![5];
|
||||
let framed = LengthDelimited::new(Cursor::new(data), UviBytes::<Vec<_>>::default());
|
||||
let framed = LengthDelimited::new(Cursor::new(data));
|
||||
let recved = framed.collect().wait();
|
||||
if let Err(io_err) = recved {
|
||||
assert_eq!(io_err.kind(), ErrorKind::UnexpectedEof)
|
||||
@ -308,7 +327,7 @@ mod tests {
|
||||
#[test]
|
||||
fn unexpected_eof_in_data2() {
|
||||
let data = vec![5, 9, 8, 7];
|
||||
let framed = LengthDelimited::new(Cursor::new(data), UviBytes::<Vec<_>>::default());
|
||||
let framed = LengthDelimited::new(Cursor::new(data));
|
||||
let recved = framed.collect().wait();
|
||||
if let Err(io_err) = recved {
|
||||
assert_eq!(io_err.kind(), ErrorKind::UnexpectedEof)
|
||||
|
@ -21,7 +21,7 @@
|
||||
//! # Multistream-select
|
||||
//!
|
||||
//! This crate implements the `multistream-select` protocol, which is the protocol used by libp2p
|
||||
//! to negotiate which protocol to use with the remote.
|
||||
//! to negotiate which protocol to use with the remote on a connection or substream.
|
||||
//!
|
||||
//! > **Note**: This crate is used by the internals of *libp2p*, and it is not required to
|
||||
//! > understand it in order to use *libp2p*.
|
||||
@ -76,6 +76,7 @@ mod protocol;
|
||||
|
||||
use futures::prelude::*;
|
||||
use std::io;
|
||||
use tokio_io::{AsyncRead, AsyncWrite};
|
||||
|
||||
pub use self::dialer_select::{dialer_select_proto, DialerSelectFuture};
|
||||
pub use self::error::ProtocolChoiceError;
|
||||
@ -93,9 +94,9 @@ where
|
||||
}
|
||||
}
|
||||
|
||||
impl<TInner> tokio_io::AsyncRead for Negotiated<TInner>
|
||||
impl<TInner> AsyncRead for Negotiated<TInner>
|
||||
where
|
||||
TInner: tokio_io::AsyncRead
|
||||
TInner: AsyncRead
|
||||
{
|
||||
unsafe fn prepare_uninitialized_buffer(&self, buf: &mut [u8]) -> bool {
|
||||
self.0.prepare_uninitialized_buffer(buf)
|
||||
@ -119,9 +120,9 @@ where
|
||||
}
|
||||
}
|
||||
|
||||
impl<TInner> tokio_io::AsyncWrite for Negotiated<TInner>
|
||||
impl<TInner> AsyncWrite for Negotiated<TInner>
|
||||
where
|
||||
TInner: tokio_io::AsyncWrite
|
||||
TInner: AsyncWrite
|
||||
{
|
||||
fn shutdown(&mut self) -> Poll<(), io::Error> {
|
||||
self.0.shutdown()
|
||||
|
@ -23,10 +23,10 @@
|
||||
|
||||
use futures::{prelude::*, sink, stream::StreamFuture};
|
||||
use crate::protocol::{
|
||||
DialerToListenerMessage,
|
||||
Request,
|
||||
Response,
|
||||
Listener,
|
||||
ListenerFuture,
|
||||
ListenerToDialerMessage
|
||||
};
|
||||
use log::{debug, trace};
|
||||
use std::mem;
|
||||
@ -126,13 +126,13 @@ where
|
||||
Err((e, _)) => return Err(ProtocolChoiceError::from(e))
|
||||
};
|
||||
match msg {
|
||||
Some(DialerToListenerMessage::ProtocolsListRequest) => {
|
||||
Some(Request::ListProtocols) => {
|
||||
trace!("protocols list response: {:?}", protocols
|
||||
.into_iter()
|
||||
.map(|p| p.as_ref().into())
|
||||
.collect::<Vec<Vec<u8>>>());
|
||||
let list = protocols.into_iter().collect();
|
||||
let msg = ListenerToDialerMessage::ProtocolsListResponse { list };
|
||||
let supported = protocols.into_iter().collect();
|
||||
let msg = Response::SupportedProtocols { protocols: supported };
|
||||
let sender = listener.send(msg);
|
||||
self.inner = ListenerSelectState::Outgoing {
|
||||
sender,
|
||||
@ -140,12 +140,12 @@ where
|
||||
outcome: None
|
||||
}
|
||||
}
|
||||
Some(DialerToListenerMessage::ProtocolRequest { name }) => {
|
||||
Some(Request::Protocol { name }) => {
|
||||
let mut outcome = None;
|
||||
let mut send_back = ListenerToDialerMessage::NotAvailable;
|
||||
let mut send_back = Response::ProtocolNotAvailable;
|
||||
for supported in &protocols {
|
||||
if name.as_ref() == supported.as_ref() {
|
||||
send_back = ListenerToDialerMessage::ProtocolAck {
|
||||
send_back = Response::Protocol {
|
||||
name: supported.clone()
|
||||
};
|
||||
outcome = Some(supported);
|
||||
|
@ -20,23 +20,25 @@
|
||||
|
||||
//! Contains the `Dialer` wrapper, which allows raw communications with a listener.
|
||||
|
||||
use bytes::{BufMut, Bytes, BytesMut};
|
||||
use super::*;
|
||||
|
||||
use bytes::{Bytes, BytesMut};
|
||||
use crate::length_delimited::LengthDelimited;
|
||||
use crate::protocol::DialerToListenerMessage;
|
||||
use crate::protocol::ListenerToDialerMessage;
|
||||
use crate::protocol::MultistreamSelectError;
|
||||
use crate::protocol::MULTISTREAM_PROTOCOL_WITH_LF;
|
||||
use crate::protocol::{Request, Response, MultistreamSelectError};
|
||||
use futures::{prelude::*, sink, Async, StartSend, try_ready};
|
||||
use std::io;
|
||||
use tokio_codec::Encoder;
|
||||
use tokio_io::{AsyncRead, AsyncWrite};
|
||||
use unsigned_varint::{decode, codec::Uvi};
|
||||
use std::marker;
|
||||
use unsigned_varint as uvi;
|
||||
|
||||
/// The maximum number of supported protocols that can be processed.
|
||||
const MAX_PROTOCOLS: usize = 1000;
|
||||
|
||||
/// Wraps around a `AsyncRead+AsyncWrite`.
|
||||
/// Assumes that we're on the dialer's side. Produces and accepts messages.
|
||||
pub struct Dialer<R, N> {
|
||||
inner: LengthDelimited<R, MessageEncoder<N>>,
|
||||
handshake_finished: bool
|
||||
inner: LengthDelimited<R>,
|
||||
handshake_finished: bool,
|
||||
_protocol_name: marker::PhantomData<N>,
|
||||
}
|
||||
|
||||
impl<R, N> Dialer<R, N>
|
||||
@ -45,15 +47,16 @@ where
|
||||
N: AsRef<[u8]>
|
||||
{
|
||||
pub fn dial(inner: R) -> DialerFuture<R, N> {
|
||||
let codec = MessageEncoder(std::marker::PhantomData);
|
||||
let sender = LengthDelimited::new(inner, codec);
|
||||
let io = LengthDelimited::new(inner);
|
||||
let mut buf = BytesMut::new();
|
||||
Header::Multistream10.encode(&mut buf);
|
||||
DialerFuture {
|
||||
inner: sender.send(Message::Header)
|
||||
inner: io.send(buf.freeze()),
|
||||
_protocol_name: marker::PhantomData,
|
||||
}
|
||||
}
|
||||
|
||||
/// Grants back the socket. Typically used after a `ProtocolAck` has been received.
|
||||
#[inline]
|
||||
pub fn into_inner(self) -> R {
|
||||
self.inner.into_inner()
|
||||
}
|
||||
@ -64,24 +67,22 @@ where
|
||||
R: AsyncRead + AsyncWrite,
|
||||
N: AsRef<[u8]>
|
||||
{
|
||||
type SinkItem = DialerToListenerMessage<N>;
|
||||
type SinkItem = Request<N>;
|
||||
type SinkError = MultistreamSelectError;
|
||||
|
||||
#[inline]
|
||||
fn start_send(&mut self, item: Self::SinkItem) -> StartSend<Self::SinkItem, Self::SinkError> {
|
||||
match self.inner.start_send(Message::Body(item))? {
|
||||
AsyncSink::NotReady(Message::Body(item)) => Ok(AsyncSink::NotReady(item)),
|
||||
AsyncSink::NotReady(Message::Header) => unreachable!(),
|
||||
AsyncSink::Ready => Ok(AsyncSink::Ready)
|
||||
fn start_send(&mut self, request: Self::SinkItem) -> StartSend<Self::SinkItem, Self::SinkError> {
|
||||
let mut msg = BytesMut::new();
|
||||
request.encode(&mut msg)?;
|
||||
match self.inner.start_send(msg.freeze())? {
|
||||
AsyncSink::NotReady(_) => Ok(AsyncSink::NotReady(request)),
|
||||
AsyncSink::Ready => Ok(AsyncSink::Ready),
|
||||
}
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn poll_complete(&mut self) -> Poll<(), Self::SinkError> {
|
||||
Ok(self.inner.poll_complete()?)
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn close(&mut self) -> Poll<(), Self::SinkError> {
|
||||
Ok(self.inner.close()?)
|
||||
}
|
||||
@ -91,20 +92,20 @@ impl<R, N> Stream for Dialer<R, N>
|
||||
where
|
||||
R: AsyncRead + AsyncWrite
|
||||
{
|
||||
type Item = ListenerToDialerMessage<Bytes>;
|
||||
type Item = Response<Bytes>;
|
||||
type Error = MultistreamSelectError;
|
||||
|
||||
fn poll(&mut self) -> Poll<Option<Self::Item>, Self::Error> {
|
||||
loop {
|
||||
let mut frame = match self.inner.poll() {
|
||||
Ok(Async::Ready(Some(frame))) => frame,
|
||||
let mut msg = match self.inner.poll() {
|
||||
Ok(Async::Ready(Some(msg))) => msg,
|
||||
Ok(Async::Ready(None)) => return Ok(Async::Ready(None)),
|
||||
Ok(Async::NotReady) => return Ok(Async::NotReady),
|
||||
Err(err) => return Err(err.into()),
|
||||
};
|
||||
|
||||
if !self.handshake_finished {
|
||||
if frame == MULTISTREAM_PROTOCOL_WITH_LF {
|
||||
if msg == MSG_MULTISTREAM_1_0 {
|
||||
self.handshake_finished = true;
|
||||
continue;
|
||||
} else {
|
||||
@ -112,31 +113,31 @@ where
|
||||
}
|
||||
}
|
||||
|
||||
if frame.get(0) == Some(&b'/') && frame.last() == Some(&b'\n') {
|
||||
let frame_len = frame.len();
|
||||
let protocol = frame.split_to(frame_len - 1);
|
||||
return Ok(Async::Ready(Some(ListenerToDialerMessage::ProtocolAck {
|
||||
name: protocol
|
||||
})));
|
||||
} else if frame == b"na\n"[..] {
|
||||
return Ok(Async::Ready(Some(ListenerToDialerMessage::NotAvailable)));
|
||||
if msg.get(0) == Some(&b'/') && msg.last() == Some(&b'\n') {
|
||||
let len = msg.len();
|
||||
let name = msg.split_to(len - 1);
|
||||
return Ok(Async::Ready(Some(
|
||||
Response::Protocol { name }
|
||||
)));
|
||||
} else if msg == MSG_PROTOCOL_NA {
|
||||
return Ok(Async::Ready(Some(Response::ProtocolNotAvailable)));
|
||||
} else {
|
||||
// A varint number of protocols
|
||||
let (num_protocols, mut remaining) = decode::usize(&frame)?;
|
||||
if num_protocols > 1000 { // TODO: configurable limit
|
||||
return Err(MultistreamSelectError::VarintParseError("too many protocols".into()))
|
||||
let (num_protocols, mut remaining) = uvi::decode::usize(&msg)?;
|
||||
if num_protocols > MAX_PROTOCOLS { // TODO: configurable limit
|
||||
return Err(MultistreamSelectError::TooManyProtocols)
|
||||
}
|
||||
let mut out = Vec::with_capacity(num_protocols);
|
||||
let mut protocols = Vec::with_capacity(num_protocols);
|
||||
for _ in 0 .. num_protocols {
|
||||
let (len, rem) = decode::usize(remaining)?;
|
||||
let (len, rem) = uvi::decode::usize(remaining)?;
|
||||
if len == 0 || len > rem.len() || rem[len - 1] != b'\n' {
|
||||
return Err(MultistreamSelectError::UnknownMessage)
|
||||
}
|
||||
out.push(Bytes::from(&rem[.. len - 1]));
|
||||
protocols.push(Bytes::from(&rem[.. len - 1]));
|
||||
remaining = &rem[len ..]
|
||||
}
|
||||
return Ok(Async::Ready(Some(
|
||||
ListenerToDialerMessage::ProtocolsListResponse { list: out },
|
||||
Response::SupportedProtocols { protocols },
|
||||
)));
|
||||
}
|
||||
}
|
||||
@ -145,7 +146,8 @@ where
|
||||
|
||||
/// Future, returned by `Dialer::new`, which send the handshake and returns the actual `Dialer`.
|
||||
pub struct DialerFuture<T: AsyncWrite, N: AsRef<[u8]>> {
|
||||
inner: sink::Send<LengthDelimited<T, MessageEncoder<N>>>
|
||||
inner: sink::Send<LengthDelimited<T>>,
|
||||
_protocol_name: marker::PhantomData<N>,
|
||||
}
|
||||
|
||||
impl<T: AsyncWrite, N: AsRef<[u8]>> Future for DialerFuture<T, N> {
|
||||
@ -154,57 +156,17 @@ impl<T: AsyncWrite, N: AsRef<[u8]>> Future for DialerFuture<T, N> {
|
||||
|
||||
fn poll(&mut self) -> Poll<Self::Item, Self::Error> {
|
||||
let inner = try_ready!(self.inner.poll());
|
||||
Ok(Async::Ready(Dialer { inner, handshake_finished: false }))
|
||||
}
|
||||
}
|
||||
|
||||
/// tokio-codec `Encoder` handling `DialerToListenerMessage` values.
|
||||
struct MessageEncoder<N>(std::marker::PhantomData<N>);
|
||||
|
||||
enum Message<N> {
|
||||
Header,
|
||||
Body(DialerToListenerMessage<N>)
|
||||
}
|
||||
|
||||
impl<N: AsRef<[u8]>> Encoder for MessageEncoder<N> {
|
||||
type Item = Message<N>;
|
||||
type Error = MultistreamSelectError;
|
||||
|
||||
fn encode(&mut self, item: Self::Item, dest: &mut BytesMut) -> Result<(), Self::Error> {
|
||||
match item {
|
||||
Message::Header => {
|
||||
Uvi::<usize>::default().encode(MULTISTREAM_PROTOCOL_WITH_LF.len(), dest)?;
|
||||
dest.reserve(MULTISTREAM_PROTOCOL_WITH_LF.len());
|
||||
dest.put(MULTISTREAM_PROTOCOL_WITH_LF);
|
||||
Ok(())
|
||||
}
|
||||
Message::Body(DialerToListenerMessage::ProtocolRequest { name }) => {
|
||||
if !name.as_ref().starts_with(b"/") {
|
||||
return Err(MultistreamSelectError::WrongProtocolName)
|
||||
}
|
||||
let len = name.as_ref().len() + 1; // + 1 for \n
|
||||
if len > std::u16::MAX as usize {
|
||||
return Err(io::Error::new(io::ErrorKind::InvalidData, "name too long").into())
|
||||
}
|
||||
Uvi::<usize>::default().encode(len, dest)?;
|
||||
dest.reserve(len);
|
||||
dest.put(name.as_ref());
|
||||
dest.put(&b"\n"[..]);
|
||||
Ok(())
|
||||
}
|
||||
Message::Body(DialerToListenerMessage::ProtocolsListRequest) => {
|
||||
Uvi::<usize>::default().encode(3, dest)?;
|
||||
dest.reserve(3);
|
||||
dest.put(&b"ls\n"[..]);
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
Ok(Async::Ready(Dialer {
|
||||
inner,
|
||||
handshake_finished: false,
|
||||
_protocol_name: marker::PhantomData,
|
||||
}))
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use crate::protocol::{Dialer, DialerToListenerMessage, MultistreamSelectError};
|
||||
use super::*;
|
||||
use tokio::runtime::current_thread::Runtime;
|
||||
use tokio_tcp::{TcpListener, TcpStream};
|
||||
use futures::Future;
|
||||
@ -225,13 +187,13 @@ mod tests {
|
||||
.from_err()
|
||||
.and_then(move |stream| Dialer::dial(stream))
|
||||
.and_then(move |dialer| {
|
||||
let p = b"invalid_name";
|
||||
dialer.send(DialerToListenerMessage::ProtocolRequest { name: p })
|
||||
let name = b"invalid_name";
|
||||
dialer.send(Request::Protocol { name })
|
||||
});
|
||||
|
||||
let mut rt = Runtime::new().unwrap();
|
||||
match rt.block_on(server.join(client)) {
|
||||
Err(MultistreamSelectError::WrongProtocolName) => (),
|
||||
Err(MultistreamSelectError::InvalidProtocolName) => (),
|
||||
_ => panic!(),
|
||||
}
|
||||
}
|
||||
|
@ -20,7 +20,7 @@
|
||||
|
||||
//! Contains the error structs for the low-level protocol handling.
|
||||
|
||||
use std::error;
|
||||
use std::error::Error;
|
||||
use std::fmt;
|
||||
use std::io;
|
||||
use unsigned_varint::decode;
|
||||
@ -38,29 +38,25 @@ pub enum MultistreamSelectError {
|
||||
UnknownMessage,
|
||||
|
||||
/// Protocol names must always start with `/`, otherwise this error is returned.
|
||||
WrongProtocolName,
|
||||
InvalidProtocolName,
|
||||
|
||||
/// Failure to parse variable-length integer.
|
||||
// TODO: we don't include the actual error, because that would remove Send from the enum
|
||||
VarintParseError(String),
|
||||
/// Too many protocols have been returned by the remote.
|
||||
TooManyProtocols,
|
||||
}
|
||||
|
||||
impl From<io::Error> for MultistreamSelectError {
|
||||
#[inline]
|
||||
fn from(err: io::Error) -> MultistreamSelectError {
|
||||
MultistreamSelectError::IoError(err)
|
||||
}
|
||||
}
|
||||
|
||||
impl From<decode::Error> for MultistreamSelectError {
|
||||
#[inline]
|
||||
fn from(err: decode::Error) -> MultistreamSelectError {
|
||||
MultistreamSelectError::VarintParseError(err.to_string())
|
||||
Self::from(io::Error::new(io::ErrorKind::InvalidData, err.to_string()))
|
||||
}
|
||||
}
|
||||
|
||||
impl error::Error for MultistreamSelectError {
|
||||
#[inline]
|
||||
impl Error for MultistreamSelectError {
|
||||
fn description(&self) -> &str {
|
||||
match *self {
|
||||
MultistreamSelectError::IoError(_) => "I/O error",
|
||||
@ -68,16 +64,15 @@ impl error::Error for MultistreamSelectError {
|
||||
"the remote doesn't use the same multistream-select protocol as we do"
|
||||
}
|
||||
MultistreamSelectError::UnknownMessage => "received an unknown message from the remote",
|
||||
MultistreamSelectError::WrongProtocolName => {
|
||||
MultistreamSelectError::InvalidProtocolName => {
|
||||
"protocol names must always start with `/`, otherwise this error is returned"
|
||||
}
|
||||
MultistreamSelectError::VarintParseError(_) => {
|
||||
"failure to parse variable-length integer"
|
||||
}
|
||||
MultistreamSelectError::TooManyProtocols =>
|
||||
"Too many protocols."
|
||||
}
|
||||
}
|
||||
|
||||
fn cause(&self) -> Option<&dyn error::Error> {
|
||||
fn source(&self) -> Option<&(dyn Error + 'static)> {
|
||||
match *self {
|
||||
MultistreamSelectError::IoError(ref err) => Some(err),
|
||||
_ => None,
|
||||
@ -86,8 +81,7 @@ impl error::Error for MultistreamSelectError {
|
||||
}
|
||||
|
||||
impl fmt::Display for MultistreamSelectError {
|
||||
#[inline]
|
||||
fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> {
|
||||
write!(fmt, "{}", error::Error::description(self))
|
||||
write!(fmt, "{}", Error::description(self))
|
||||
}
|
||||
}
|
||||
|
@ -20,23 +20,21 @@
|
||||
|
||||
//! Contains the `Listener` wrapper, which allows raw communications with a dialer.
|
||||
|
||||
use bytes::{BufMut, Bytes, BytesMut};
|
||||
use super::*;
|
||||
|
||||
use bytes::{Bytes, BytesMut};
|
||||
use crate::length_delimited::LengthDelimited;
|
||||
use crate::protocol::DialerToListenerMessage;
|
||||
use crate::protocol::ListenerToDialerMessage;
|
||||
use crate::protocol::MultistreamSelectError;
|
||||
use crate::protocol::MULTISTREAM_PROTOCOL_WITH_LF;
|
||||
use crate::protocol::{Request, Response, MultistreamSelectError};
|
||||
use futures::{prelude::*, sink, stream::StreamFuture};
|
||||
use log::{debug, trace};
|
||||
use std::{io, mem};
|
||||
use tokio_codec::Encoder;
|
||||
use std::{marker, mem};
|
||||
use tokio_io::{AsyncRead, AsyncWrite};
|
||||
use unsigned_varint::{encode, codec::Uvi};
|
||||
|
||||
/// Wraps around a `AsyncRead+AsyncWrite`. Assumes that we're on the listener's side. Produces and
|
||||
/// accepts messages.
|
||||
pub struct Listener<R, N> {
|
||||
inner: LengthDelimited<R, MessageEncoder<N>>
|
||||
inner: LengthDelimited<R>,
|
||||
_protocol_name: marker::PhantomData<N>,
|
||||
}
|
||||
|
||||
impl<R, N> Listener<R, N>
|
||||
@ -47,16 +45,15 @@ where
|
||||
/// Takes ownership of a socket and starts the handshake. If the handshake succeeds, the
|
||||
/// future returns a `Listener`.
|
||||
pub fn listen(inner: R) -> ListenerFuture<R, N> {
|
||||
let codec = MessageEncoder(std::marker::PhantomData);
|
||||
let inner = LengthDelimited::new(inner, codec);
|
||||
let inner = LengthDelimited::new(inner);
|
||||
ListenerFuture {
|
||||
inner: ListenerFutureState::Await { inner: inner.into_future() }
|
||||
inner: ListenerFutureState::Await { inner: inner.into_future() },
|
||||
_protocol_name: marker::PhantomData,
|
||||
}
|
||||
}
|
||||
|
||||
/// Grants back the socket. Typically used after a `ProtocolRequest` has been received and a
|
||||
/// `ProtocolAck` has been sent back.
|
||||
#[inline]
|
||||
pub fn into_inner(self) -> R {
|
||||
self.inner.into_inner()
|
||||
}
|
||||
@ -67,24 +64,22 @@ where
|
||||
R: AsyncRead + AsyncWrite,
|
||||
N: AsRef<[u8]>
|
||||
{
|
||||
type SinkItem = ListenerToDialerMessage<N>;
|
||||
type SinkItem = Response<N>;
|
||||
type SinkError = MultistreamSelectError;
|
||||
|
||||
#[inline]
|
||||
fn start_send(&mut self, item: Self::SinkItem) -> StartSend<Self::SinkItem, Self::SinkError> {
|
||||
match self.inner.start_send(Message::Body(item))? {
|
||||
AsyncSink::NotReady(Message::Body(item)) => Ok(AsyncSink::NotReady(item)),
|
||||
AsyncSink::NotReady(Message::Header) => unreachable!(),
|
||||
fn start_send(&mut self, response: Self::SinkItem) -> StartSend<Self::SinkItem, Self::SinkError> {
|
||||
let mut msg = BytesMut::new();
|
||||
response.encode(&mut msg)?;
|
||||
match self.inner.start_send(msg.freeze())? {
|
||||
AsyncSink::NotReady(_) => Ok(AsyncSink::NotReady(response)),
|
||||
AsyncSink::Ready => Ok(AsyncSink::Ready)
|
||||
}
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn poll_complete(&mut self) -> Poll<(), Self::SinkError> {
|
||||
Ok(self.inner.poll_complete()?)
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn close(&mut self) -> Poll<(), Self::SinkError> {
|
||||
Ok(self.inner.close()?)
|
||||
}
|
||||
@ -94,26 +89,26 @@ impl<R, N> Stream for Listener<R, N>
|
||||
where
|
||||
R: AsyncRead + AsyncWrite,
|
||||
{
|
||||
type Item = DialerToListenerMessage<Bytes>;
|
||||
type Item = Request<Bytes>;
|
||||
type Error = MultistreamSelectError;
|
||||
|
||||
fn poll(&mut self) -> Poll<Option<Self::Item>, Self::Error> {
|
||||
let mut frame = match self.inner.poll() {
|
||||
Ok(Async::Ready(Some(frame))) => frame,
|
||||
let mut msg = match self.inner.poll() {
|
||||
Ok(Async::Ready(Some(msg))) => msg,
|
||||
Ok(Async::Ready(None)) => return Ok(Async::Ready(None)),
|
||||
Ok(Async::NotReady) => return Ok(Async::NotReady),
|
||||
Err(err) => return Err(err.into()),
|
||||
};
|
||||
|
||||
if frame.get(0) == Some(&b'/') && frame.last() == Some(&b'\n') {
|
||||
let frame_len = frame.len();
|
||||
let protocol = frame.split_to(frame_len - 1);
|
||||
if msg.get(0) == Some(&b'/') && msg.last() == Some(&b'\n') {
|
||||
let len = msg.len();
|
||||
let name = msg.split_to(len - 1);
|
||||
Ok(Async::Ready(Some(
|
||||
DialerToListenerMessage::ProtocolRequest { name: protocol },
|
||||
Request::Protocol { name },
|
||||
)))
|
||||
} else if frame == b"ls\n"[..] {
|
||||
} else if msg == MSG_LS {
|
||||
Ok(Async::Ready(Some(
|
||||
DialerToListenerMessage::ProtocolsListRequest,
|
||||
Request::ListProtocols,
|
||||
)))
|
||||
} else {
|
||||
Err(MultistreamSelectError::UnknownMessage)
|
||||
@ -124,16 +119,17 @@ where
|
||||
|
||||
/// Future, returned by `Listener::new` which performs the handshake and returns
|
||||
/// the `Listener` if successful.
|
||||
pub struct ListenerFuture<T: AsyncRead + AsyncWrite, N: AsRef<[u8]>> {
|
||||
inner: ListenerFutureState<T, N>
|
||||
pub struct ListenerFuture<T: AsyncRead + AsyncWrite, N> {
|
||||
inner: ListenerFutureState<T>,
|
||||
_protocol_name: marker::PhantomData<N>,
|
||||
}
|
||||
|
||||
enum ListenerFutureState<T: AsyncRead + AsyncWrite, N: AsRef<[u8]>> {
|
||||
enum ListenerFutureState<T: AsyncRead + AsyncWrite> {
|
||||
Await {
|
||||
inner: StreamFuture<LengthDelimited<T, MessageEncoder<N>>>
|
||||
inner: StreamFuture<LengthDelimited<T>>
|
||||
},
|
||||
Reply {
|
||||
sender: sink::Send<LengthDelimited<T, MessageEncoder<N>>>
|
||||
sender: sink::Send<LengthDelimited<T>>
|
||||
},
|
||||
Undefined
|
||||
}
|
||||
@ -155,12 +151,14 @@ impl<T: AsyncRead + AsyncWrite, N: AsRef<[u8]>> Future for ListenerFuture<T, N>
|
||||
}
|
||||
Err((e, _)) => return Err(MultistreamSelectError::from(e))
|
||||
};
|
||||
if msg.as_ref().map(|b| &b[..]) != Some(MULTISTREAM_PROTOCOL_WITH_LF) {
|
||||
debug!("failed handshake; received: {:?}", msg);
|
||||
if msg.as_ref().map(|b| &b[..]) != Some(MSG_MULTISTREAM_1_0) {
|
||||
debug!("Unexpected message: {:?}", msg);
|
||||
return Err(MultistreamSelectError::FailedHandshake)
|
||||
}
|
||||
trace!("sending back /multistream/<version> to finish the handshake");
|
||||
let sender = socket.send(Message::Header);
|
||||
let mut frame = BytesMut::new();
|
||||
Header::Multistream10.encode(&mut frame);
|
||||
let sender = socket.send(frame.freeze());
|
||||
self.inner = ListenerFutureState::Reply { sender }
|
||||
}
|
||||
ListenerFutureState::Reply { mut sender } => {
|
||||
@ -171,70 +169,13 @@ impl<T: AsyncRead + AsyncWrite, N: AsRef<[u8]>> Future for ListenerFuture<T, N>
|
||||
return Ok(Async::NotReady)
|
||||
}
|
||||
};
|
||||
return Ok(Async::Ready(Listener { inner: listener }))
|
||||
return Ok(Async::Ready(Listener {
|
||||
inner: listener,
|
||||
_protocol_name: marker::PhantomData
|
||||
}))
|
||||
}
|
||||
ListenerFutureState::Undefined => panic!("ListenerFutureState::poll called after completion")
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// tokio-codec `Encoder` handling `ListenerToDialerMessage` values.
|
||||
struct MessageEncoder<N>(std::marker::PhantomData<N>);
|
||||
|
||||
enum Message<N> {
|
||||
Header,
|
||||
Body(ListenerToDialerMessage<N>)
|
||||
}
|
||||
|
||||
impl<N: AsRef<[u8]>> Encoder for MessageEncoder<N> {
|
||||
type Item = Message<N>;
|
||||
type Error = MultistreamSelectError;
|
||||
|
||||
fn encode(&mut self, item: Self::Item, dest: &mut BytesMut) -> Result<(), Self::Error> {
|
||||
match item {
|
||||
Message::Header => {
|
||||
Uvi::<usize>::default().encode(MULTISTREAM_PROTOCOL_WITH_LF.len(), dest)?;
|
||||
dest.reserve(MULTISTREAM_PROTOCOL_WITH_LF.len());
|
||||
dest.put(MULTISTREAM_PROTOCOL_WITH_LF);
|
||||
Ok(())
|
||||
}
|
||||
Message::Body(ListenerToDialerMessage::ProtocolAck { name }) => {
|
||||
if !name.as_ref().starts_with(b"/") {
|
||||
return Err(MultistreamSelectError::WrongProtocolName)
|
||||
}
|
||||
let len = name.as_ref().len() + 1; // + 1 for \n
|
||||
if len > std::u16::MAX as usize {
|
||||
return Err(io::Error::new(io::ErrorKind::InvalidData, "name too long").into())
|
||||
}
|
||||
Uvi::<usize>::default().encode(len, dest)?;
|
||||
dest.reserve(len);
|
||||
dest.put(name.as_ref());
|
||||
dest.put(&b"\n"[..]);
|
||||
Ok(())
|
||||
}
|
||||
Message::Body(ListenerToDialerMessage::ProtocolsListResponse { list }) => {
|
||||
let mut buf = encode::usize_buffer();
|
||||
let mut out_msg = Vec::from(encode::usize(list.len(), &mut buf));
|
||||
for e in &list {
|
||||
if e.as_ref().len() + 1 > std::u16::MAX as usize {
|
||||
return Err(io::Error::new(io::ErrorKind::InvalidData, "name too long").into())
|
||||
}
|
||||
out_msg.extend(encode::usize(e.as_ref().len() + 1, &mut buf)); // +1 for '\n'
|
||||
out_msg.extend_from_slice(e.as_ref());
|
||||
out_msg.push(b'\n')
|
||||
}
|
||||
let len = encode::usize(out_msg.len(), &mut buf);
|
||||
dest.reserve(len.len() + out_msg.len());
|
||||
dest.put(len);
|
||||
dest.put(out_msg);
|
||||
Ok(())
|
||||
}
|
||||
Message::Body(ListenerToDialerMessage::NotAvailable) => {
|
||||
Uvi::<usize>::default().encode(3, dest)?;
|
||||
dest.reserve(3);
|
||||
dest.put(&b"na\n"[..]);
|
||||
Ok(())
|
||||
ListenerFutureState::Undefined =>
|
||||
panic!("ListenerFutureState::poll called after completion")
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -242,12 +183,12 @@ impl<N: AsRef<[u8]>> Encoder for MessageEncoder<N> {
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use tokio::runtime::current_thread::Runtime;
|
||||
use tokio_tcp::{TcpListener, TcpStream};
|
||||
use bytes::Bytes;
|
||||
use futures::Future;
|
||||
use futures::{Sink, Stream};
|
||||
use crate::protocol::{Dialer, Listener, ListenerToDialerMessage, MultistreamSelectError};
|
||||
|
||||
#[test]
|
||||
fn wrong_proto_name() {
|
||||
@ -260,8 +201,8 @@ mod tests {
|
||||
.map_err(|(e, _)| e.into())
|
||||
.and_then(move |(connec, _)| Listener::listen(connec.unwrap()))
|
||||
.and_then(|listener| {
|
||||
let proto_name = Bytes::from("invalid-proto");
|
||||
listener.send(ListenerToDialerMessage::ProtocolAck { name: proto_name })
|
||||
let name = Bytes::from("invalid-proto");
|
||||
listener.send(Response::Protocol { name })
|
||||
});
|
||||
|
||||
let client = TcpStream::connect(&listener_addr)
|
||||
@ -270,7 +211,7 @@ mod tests {
|
||||
|
||||
let mut rt = Runtime::new().unwrap();
|
||||
match rt.block_on(server.join(client)) {
|
||||
Err(MultistreamSelectError::WrongProtocolName) => (),
|
||||
Err(MultistreamSelectError::InvalidProtocolName) => (),
|
||||
_ => panic!(),
|
||||
}
|
||||
}
|
||||
|
@ -20,47 +20,125 @@
|
||||
|
||||
//! Contains lower-level structs to handle the multistream protocol.
|
||||
|
||||
const MSG_MULTISTREAM_1_0: &[u8] = b"/multistream/1.0.0\n";
|
||||
const MSG_PROTOCOL_NA: &[u8] = b"na\n";
|
||||
const MSG_LS: &[u8] = b"ls\n";
|
||||
|
||||
mod dialer;
|
||||
mod error;
|
||||
mod listener;
|
||||
|
||||
const MULTISTREAM_PROTOCOL_WITH_LF: &[u8] = b"/multistream/1.0.0\n";
|
||||
|
||||
pub use self::dialer::{Dialer, DialerFuture};
|
||||
pub use self::error::MultistreamSelectError;
|
||||
pub use self::listener::{Listener, ListenerFuture};
|
||||
|
||||
use bytes::{BytesMut, BufMut};
|
||||
use unsigned_varint as uvi;
|
||||
|
||||
pub enum Header {
|
||||
Multistream10
|
||||
}
|
||||
|
||||
impl Header {
|
||||
fn encode(&self, dest: &mut BytesMut) {
|
||||
match self {
|
||||
Header::Multistream10 => {
|
||||
dest.reserve(MSG_MULTISTREAM_1_0.len());
|
||||
dest.put(MSG_MULTISTREAM_1_0);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Message sent from the dialer to the listener.
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
pub enum DialerToListenerMessage<N> {
|
||||
pub enum Request<N> {
|
||||
/// The dialer wants us to use a protocol.
|
||||
///
|
||||
/// If this is accepted (by receiving back a `ProtocolAck`), then we immediately start
|
||||
/// communicating in the new protocol.
|
||||
ProtocolRequest {
|
||||
Protocol {
|
||||
/// Name of the protocol.
|
||||
name: N
|
||||
},
|
||||
|
||||
/// The dialer requested the list of protocols that the listener supports.
|
||||
ProtocolsListRequest,
|
||||
ListProtocols,
|
||||
}
|
||||
|
||||
impl<N: AsRef<[u8]>> Request<N> {
|
||||
fn encode(&self, dest: &mut BytesMut) -> Result<(), MultistreamSelectError> {
|
||||
match self {
|
||||
Request::Protocol { name } => {
|
||||
if !name.as_ref().starts_with(b"/") {
|
||||
return Err(MultistreamSelectError::InvalidProtocolName)
|
||||
}
|
||||
let len = name.as_ref().len() + 1; // + 1 for \n
|
||||
dest.reserve(len);
|
||||
dest.put(name.as_ref());
|
||||
dest.put(&b"\n"[..]);
|
||||
Ok(())
|
||||
}
|
||||
Request::ListProtocols => {
|
||||
dest.reserve(MSG_LS.len());
|
||||
dest.put(MSG_LS);
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
/// Message sent from the listener to the dialer.
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
pub enum ListenerToDialerMessage<N> {
|
||||
pub enum Response<N> {
|
||||
/// The protocol requested by the dialer is accepted. The socket immediately starts using the
|
||||
/// new protocol.
|
||||
ProtocolAck { name: N },
|
||||
Protocol { name: N },
|
||||
|
||||
/// The protocol requested by the dialer is not supported or available.
|
||||
NotAvailable,
|
||||
ProtocolNotAvailable,
|
||||
|
||||
/// Response to the request for the list of protocols.
|
||||
ProtocolsListResponse {
|
||||
SupportedProtocols {
|
||||
/// The list of protocols.
|
||||
// TODO: use some sort of iterator
|
||||
list: Vec<N>,
|
||||
protocols: Vec<N>,
|
||||
},
|
||||
}
|
||||
|
||||
impl<N: AsRef<[u8]>> Response<N> {
|
||||
fn encode(&self, dest: &mut BytesMut) -> Result<(), MultistreamSelectError> {
|
||||
match self {
|
||||
Response::Protocol { name } => {
|
||||
if !name.as_ref().starts_with(b"/") {
|
||||
return Err(MultistreamSelectError::InvalidProtocolName)
|
||||
}
|
||||
let len = name.as_ref().len() + 1; // + 1 for \n
|
||||
dest.reserve(len);
|
||||
dest.put(name.as_ref());
|
||||
dest.put(&b"\n"[..]);
|
||||
Ok(())
|
||||
}
|
||||
Response::SupportedProtocols { protocols } => {
|
||||
let mut buf = uvi::encode::usize_buffer();
|
||||
let mut out_msg = Vec::from(uvi::encode::usize(protocols.len(), &mut buf));
|
||||
for p in protocols {
|
||||
out_msg.extend(uvi::encode::usize(p.as_ref().len() + 1, &mut buf)); // +1 for '\n'
|
||||
out_msg.extend_from_slice(p.as_ref());
|
||||
out_msg.push(b'\n')
|
||||
}
|
||||
dest.reserve(out_msg.len());
|
||||
dest.put(out_msg);
|
||||
Ok(())
|
||||
}
|
||||
Response::ProtocolNotAvailable => {
|
||||
dest.reserve(MSG_PROTOCOL_NA.len());
|
||||
dest.put(MSG_PROTOCOL_NA);
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
@ -24,7 +24,7 @@
|
||||
|
||||
use crate::ProtocolChoiceError;
|
||||
use crate::dialer_select::{dialer_select_proto_parallel, dialer_select_proto_serial};
|
||||
use crate::protocol::{Dialer, DialerToListenerMessage, Listener, ListenerToDialerMessage};
|
||||
use crate::protocol::{Dialer, Request, Listener, Response};
|
||||
use crate::{dialer_select_proto, listener_select_proto};
|
||||
use futures::prelude::*;
|
||||
use tokio::runtime::current_thread::Runtime;
|
||||
@ -56,23 +56,23 @@ fn negotiate_with_self_succeeds() {
|
||||
.and_then(|l| l.into_future().map_err(|(e, _)| e))
|
||||
.and_then(|(msg, rest)| {
|
||||
let proto = match msg {
|
||||
Some(DialerToListenerMessage::ProtocolRequest { name }) => name,
|
||||
Some(Request::Protocol { name }) => name,
|
||||
_ => panic!(),
|
||||
};
|
||||
rest.send(ListenerToDialerMessage::ProtocolAck { name: proto })
|
||||
rest.send(Response::Protocol { name: proto })
|
||||
});
|
||||
|
||||
let client = TcpStream::connect(&listener_addr)
|
||||
.from_err()
|
||||
.and_then(move |stream| Dialer::dial(stream))
|
||||
.and_then(move |dialer| {
|
||||
let p = b"/hello/1.0.0";
|
||||
dialer.send(DialerToListenerMessage::ProtocolRequest { name: p })
|
||||
let name = b"/hello/1.0.0";
|
||||
dialer.send(Request::Protocol { name })
|
||||
})
|
||||
.and_then(move |dialer| dialer.into_future().map_err(|(e, _)| e))
|
||||
.and_then(move |(msg, _)| {
|
||||
let proto = match msg {
|
||||
Some(ListenerToDialerMessage::ProtocolAck { name }) => name,
|
||||
Some(Response::Protocol { name }) => name,
|
||||
_ => panic!(),
|
||||
};
|
||||
assert_eq!(proto, "/hello/1.0.0");
|
||||
|
Loading…
x
Reference in New Issue
Block a user