mirror of
https://github.com/fluencelabs/rust-libp2p
synced 2025-06-06 06:31:22 +00:00
Remove tokio-codec dependency from multistream-select. (#1203)
* Remove tokio-codec dependency from multistream-select. In preparation for the eventual switch from tokio to std futures. Includes some initial refactoring in preparation for further work in the context of https://github.com/libp2p/rust-libp2p/issues/659. * Reduce default buffer sizes. * Allow more than one frame to be buffered for sending. * Doc tweaks. * Remove superfluous (duplicated) Message types.
This commit is contained in:
parent
bcc7c4d349
commit
2fd941122a
@ -14,9 +14,8 @@ bytes = "0.4"
|
|||||||
futures = { version = "0.1" }
|
futures = { version = "0.1" }
|
||||||
log = "0.4"
|
log = "0.4"
|
||||||
smallvec = "0.6"
|
smallvec = "0.6"
|
||||||
tokio-codec = "0.1"
|
|
||||||
tokio-io = "0.1"
|
tokio-io = "0.1"
|
||||||
unsigned-varint = { version = "0.2.1", features = ["codec"] }
|
unsigned-varint = { version = "0.2.2" }
|
||||||
|
|
||||||
[dev-dependencies]
|
[dev-dependencies]
|
||||||
tokio = "0.1"
|
tokio = "0.1"
|
||||||
|
@ -22,19 +22,14 @@
|
|||||||
//! `multistream-select` for the dialer.
|
//! `multistream-select` for the dialer.
|
||||||
|
|
||||||
use futures::{future::Either, prelude::*, stream::StreamFuture};
|
use futures::{future::Either, prelude::*, stream::StreamFuture};
|
||||||
use crate::protocol::{
|
use crate::protocol::{Dialer, DialerFuture, Request, Response};
|
||||||
Dialer,
|
|
||||||
DialerFuture,
|
|
||||||
DialerToListenerMessage,
|
|
||||||
ListenerToDialerMessage
|
|
||||||
};
|
|
||||||
use log::trace;
|
use log::trace;
|
||||||
use std::mem;
|
use std::mem;
|
||||||
use tokio_io::{AsyncRead, AsyncWrite};
|
use tokio_io::{AsyncRead, AsyncWrite};
|
||||||
use crate::{Negotiated, ProtocolChoiceError};
|
use crate::{Negotiated, ProtocolChoiceError};
|
||||||
|
|
||||||
/// Future, returned by `dialer_select_proto`, which selects a protocol and dialer
|
/// Future, returned by `dialer_select_proto`, which selects a protocol and dialer
|
||||||
/// either sequentially of by considering all protocols in parallel.
|
/// either sequentially or by considering all protocols in parallel.
|
||||||
pub type DialerSelectFuture<R, I> = Either<DialerSelectSeq<R, I>, DialerSelectPar<R, I>>;
|
pub type DialerSelectFuture<R, I> = Either<DialerSelectSeq<R, I>, DialerSelectPar<R, I>>;
|
||||||
|
|
||||||
/// Helps selecting a protocol amongst the ones supported.
|
/// Helps selecting a protocol amongst the ones supported.
|
||||||
@ -75,7 +70,10 @@ where
|
|||||||
{
|
{
|
||||||
let protocols = protocols.into_iter();
|
let protocols = protocols.into_iter();
|
||||||
DialerSelectSeq {
|
DialerSelectSeq {
|
||||||
inner: DialerSelectSeqState::AwaitDialer { dialer_fut: Dialer::dial(inner), protocols }
|
inner: DialerSelectSeqState::AwaitDialer {
|
||||||
|
dialer_fut: Dialer::dial(inner),
|
||||||
|
protocols
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -148,9 +146,7 @@ where
|
|||||||
}
|
}
|
||||||
DialerSelectSeqState::NextProtocol { mut dialer, protocols, proto_name } => {
|
DialerSelectSeqState::NextProtocol { mut dialer, protocols, proto_name } => {
|
||||||
trace!("sending {:?}", proto_name.as_ref());
|
trace!("sending {:?}", proto_name.as_ref());
|
||||||
let req = DialerToListenerMessage::ProtocolRequest {
|
let req = Request::Protocol { name: proto_name.clone() };
|
||||||
name: proto_name.clone()
|
|
||||||
};
|
|
||||||
match dialer.start_send(req)? {
|
match dialer.start_send(req)? {
|
||||||
AsyncSink::Ready => {
|
AsyncSink::Ready => {
|
||||||
self.inner = DialerSelectSeqState::FlushProtocol {
|
self.inner = DialerSelectSeqState::FlushProtocol {
|
||||||
@ -204,12 +200,12 @@ where
|
|||||||
};
|
};
|
||||||
trace!("received {:?}", m);
|
trace!("received {:?}", m);
|
||||||
match m.ok_or(ProtocolChoiceError::UnexpectedMessage)? {
|
match m.ok_or(ProtocolChoiceError::UnexpectedMessage)? {
|
||||||
ListenerToDialerMessage::ProtocolAck { ref name }
|
Response::Protocol { ref name }
|
||||||
if name.as_ref() == proto_name.as_ref() =>
|
if name.as_ref() == proto_name.as_ref() =>
|
||||||
{
|
{
|
||||||
return Ok(Async::Ready((proto_name, Negotiated(r.into_inner()))))
|
return Ok(Async::Ready((proto_name, Negotiated(r.into_inner()))))
|
||||||
}
|
}
|
||||||
ListenerToDialerMessage::NotAvailable => {
|
Response::ProtocolNotAvailable => {
|
||||||
let proto_name = protocols.next()
|
let proto_name = protocols.next()
|
||||||
.ok_or(ProtocolChoiceError::NoProtocolFound)?;
|
.ok_or(ProtocolChoiceError::NoProtocolFound)?;
|
||||||
self.inner = DialerSelectSeqState::NextProtocol {
|
self.inner = DialerSelectSeqState::NextProtocol {
|
||||||
@ -244,9 +240,8 @@ where
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
/// Future, returned by `dialer_select_proto_parallel`, which selects a protocol and dialer in
|
/// Future, returned by `dialer_select_proto_parallel`, which selects a protocol and dialer in
|
||||||
/// parellel, by first requesting the liste of protocols supported by the remote endpoint and
|
/// parallel, by first requesting the list of protocols supported by the remote endpoint and
|
||||||
/// then selecting the most appropriate one by applying a match predicate to the result.
|
/// then selecting the most appropriate one by applying a match predicate to the result.
|
||||||
pub struct DialerSelectPar<R, I>
|
pub struct DialerSelectPar<R, I>
|
||||||
where
|
where
|
||||||
@ -319,7 +314,7 @@ where
|
|||||||
}
|
}
|
||||||
DialerSelectParState::ProtocolList { mut dialer, protocols } => {
|
DialerSelectParState::ProtocolList { mut dialer, protocols } => {
|
||||||
trace!("requesting protocols list");
|
trace!("requesting protocols list");
|
||||||
match dialer.start_send(DialerToListenerMessage::ProtocolsListRequest)? {
|
match dialer.start_send(Request::ListProtocols)? {
|
||||||
AsyncSink::Ready => {
|
AsyncSink::Ready => {
|
||||||
self.inner = DialerSelectParState::FlushListRequest {
|
self.inner = DialerSelectParState::FlushListRequest {
|
||||||
dialer,
|
dialer,
|
||||||
@ -359,15 +354,15 @@ where
|
|||||||
Err((e, _)) => return Err(ProtocolChoiceError::from(e))
|
Err((e, _)) => return Err(ProtocolChoiceError::from(e))
|
||||||
};
|
};
|
||||||
trace!("protocols list response: {:?}", resp);
|
trace!("protocols list response: {:?}", resp);
|
||||||
let list =
|
let supported =
|
||||||
if let Some(ListenerToDialerMessage::ProtocolsListResponse { list }) = resp {
|
if let Some(Response::SupportedProtocols { protocols }) = resp {
|
||||||
list
|
protocols
|
||||||
} else {
|
} else {
|
||||||
return Err(ProtocolChoiceError::UnexpectedMessage)
|
return Err(ProtocolChoiceError::UnexpectedMessage)
|
||||||
};
|
};
|
||||||
let mut found = None;
|
let mut found = None;
|
||||||
for local_name in protocols {
|
for local_name in protocols {
|
||||||
for remote_name in &list {
|
for remote_name in &supported {
|
||||||
if remote_name.as_ref() == local_name.as_ref() {
|
if remote_name.as_ref() == local_name.as_ref() {
|
||||||
found = Some(local_name);
|
found = Some(local_name);
|
||||||
break;
|
break;
|
||||||
@ -381,10 +376,8 @@ where
|
|||||||
self.inner = DialerSelectParState::Protocol { dialer, proto_name }
|
self.inner = DialerSelectParState::Protocol { dialer, proto_name }
|
||||||
}
|
}
|
||||||
DialerSelectParState::Protocol { mut dialer, proto_name } => {
|
DialerSelectParState::Protocol { mut dialer, proto_name } => {
|
||||||
trace!("requesting protocol: {:?}", proto_name.as_ref());
|
trace!("Requesting protocol: {:?}", proto_name.as_ref());
|
||||||
let req = DialerToListenerMessage::ProtocolRequest {
|
let req = Request::Protocol { name: proto_name.clone() };
|
||||||
name: proto_name.clone()
|
|
||||||
};
|
|
||||||
match dialer.start_send(req)? {
|
match dialer.start_send(req)? {
|
||||||
AsyncSink::Ready => {
|
AsyncSink::Ready => {
|
||||||
self.inner = DialerSelectParState::FlushProtocol { dialer, proto_name }
|
self.inner = DialerSelectParState::FlushProtocol { dialer, proto_name }
|
||||||
@ -420,7 +413,7 @@ where
|
|||||||
};
|
};
|
||||||
trace!("received {:?}", resp);
|
trace!("received {:?}", resp);
|
||||||
match resp {
|
match resp {
|
||||||
Some(ListenerToDialerMessage::ProtocolAck { ref name })
|
Some(Response::Protocol { ref name })
|
||||||
if name.as_ref() == proto_name.as_ref() =>
|
if name.as_ref() == proto_name.as_ref() =>
|
||||||
{
|
{
|
||||||
return Ok(Async::Ready((proto_name, Negotiated(dialer.into_inner()))))
|
return Ok(Async::Ready((proto_name, Negotiated(dialer.into_inner()))))
|
||||||
|
@ -21,9 +21,8 @@
|
|||||||
//! Main `ProtocolChoiceError` error.
|
//! Main `ProtocolChoiceError` error.
|
||||||
|
|
||||||
use crate::protocol::MultistreamSelectError;
|
use crate::protocol::MultistreamSelectError;
|
||||||
use std::error;
|
use std::error::Error;
|
||||||
use std::fmt;
|
use std::{fmt, io};
|
||||||
use std::io::Error as IoError;
|
|
||||||
|
|
||||||
/// Error that can happen when negotiating a protocol with the remote.
|
/// Error that can happen when negotiating a protocol with the remote.
|
||||||
#[derive(Debug)]
|
#[derive(Debug)]
|
||||||
@ -39,21 +38,18 @@ pub enum ProtocolChoiceError {
|
|||||||
}
|
}
|
||||||
|
|
||||||
impl From<MultistreamSelectError> for ProtocolChoiceError {
|
impl From<MultistreamSelectError> for ProtocolChoiceError {
|
||||||
#[inline]
|
|
||||||
fn from(err: MultistreamSelectError) -> ProtocolChoiceError {
|
fn from(err: MultistreamSelectError) -> ProtocolChoiceError {
|
||||||
ProtocolChoiceError::MultistreamSelectError(err)
|
ProtocolChoiceError::MultistreamSelectError(err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl From<IoError> for ProtocolChoiceError {
|
impl From<io::Error> for ProtocolChoiceError {
|
||||||
#[inline]
|
fn from(err: io::Error) -> ProtocolChoiceError {
|
||||||
fn from(err: IoError) -> ProtocolChoiceError {
|
|
||||||
MultistreamSelectError::from(err).into()
|
MultistreamSelectError::from(err).into()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl error::Error for ProtocolChoiceError {
|
impl Error for ProtocolChoiceError {
|
||||||
#[inline]
|
|
||||||
fn description(&self) -> &str {
|
fn description(&self) -> &str {
|
||||||
match *self {
|
match *self {
|
||||||
ProtocolChoiceError::MultistreamSelectError(_) => "error in the protocol",
|
ProtocolChoiceError::MultistreamSelectError(_) => "error in the protocol",
|
||||||
@ -66,7 +62,7 @@ impl error::Error for ProtocolChoiceError {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn cause(&self) -> Option<&dyn error::Error> {
|
fn source(&self) -> Option<&(dyn Error + 'static)> {
|
||||||
match *self {
|
match *self {
|
||||||
ProtocolChoiceError::MultistreamSelectError(ref err) => Some(err),
|
ProtocolChoiceError::MultistreamSelectError(ref err) => Some(err),
|
||||||
_ => None,
|
_ => None,
|
||||||
@ -75,8 +71,7 @@ impl error::Error for ProtocolChoiceError {
|
|||||||
}
|
}
|
||||||
|
|
||||||
impl fmt::Display for ProtocolChoiceError {
|
impl fmt::Display for ProtocolChoiceError {
|
||||||
#[inline]
|
|
||||||
fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> {
|
fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> {
|
||||||
write!(fmt, "{}", error::Error::description(self))
|
write!(fmt, "{}", Error::description(self))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -18,78 +18,81 @@
|
|||||||
// FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
|
// FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
|
||||||
// DEALINGS IN THE SOFTWARE.
|
// DEALINGS IN THE SOFTWARE.
|
||||||
|
|
||||||
use bytes::Bytes;
|
use bytes::{Bytes, BytesMut, BufMut};
|
||||||
use futures::{Async, Poll, Sink, StartSend, Stream};
|
use futures::{try_ready, Async, Poll, Sink, StartSend, Stream, AsyncSink};
|
||||||
use smallvec::SmallVec;
|
|
||||||
use std::{io, u16};
|
use std::{io, u16};
|
||||||
use tokio_codec::{Encoder, FramedWrite};
|
|
||||||
use tokio_io::{AsyncRead, AsyncWrite};
|
use tokio_io::{AsyncRead, AsyncWrite};
|
||||||
use unsigned_varint::decode;
|
use unsigned_varint as uvi;
|
||||||
|
|
||||||
/// `Stream` and `Sink` wrapping some `AsyncRead + AsyncWrite` object to read
|
const MAX_LEN_BYTES: u16 = 2;
|
||||||
|
const MAX_FRAME_SIZE: u16 = (1 << (MAX_LEN_BYTES * 8 - MAX_LEN_BYTES)) - 1;
|
||||||
|
const DEFAULT_BUFFER_SIZE: usize = 64;
|
||||||
|
|
||||||
|
/// `Stream` and `Sink` wrapping some `AsyncRead + AsyncWrite` resource to read
|
||||||
/// and write unsigned-varint prefixed frames.
|
/// and write unsigned-varint prefixed frames.
|
||||||
///
|
///
|
||||||
/// We purposely only support a frame length of under 64kiB. Frames mostly consist
|
/// We purposely only support a frame sizes up to 16KiB (2 bytes unsigned varint
|
||||||
/// in a short protocol name, which is highly unlikely to be more than 64kiB long.
|
/// frame length). Frames mostly consist in a short protocol name, which is highly
|
||||||
pub struct LengthDelimited<R, C> {
|
/// unlikely to be more than 16KiB long.
|
||||||
// The inner socket where data is pulled from.
|
pub struct LengthDelimited<R> {
|
||||||
inner: FramedWrite<R, C>,
|
/// The inner I/O resource.
|
||||||
// Intermediary buffer where we put either the length of the next frame of data, or the frame
|
inner: R,
|
||||||
// of data itself before it is returned.
|
/// Read buffer for a single incoming unsigned-varint length-delimited frame.
|
||||||
// Must always contain enough space to read data from `inner`.
|
read_buffer: BytesMut,
|
||||||
internal_buffer: SmallVec<[u8; 64]>,
|
/// Write buffer for outgoing unsigned-varint length-delimited frames.
|
||||||
// Number of bytes within `internal_buffer` that contain valid data.
|
write_buffer: BytesMut,
|
||||||
internal_buffer_pos: usize,
|
/// The current read state, alternating between reading a frame
|
||||||
// State of the decoder.
|
/// length and reading a frame payload.
|
||||||
state: State
|
read_state: ReadState,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
|
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
|
||||||
enum State {
|
enum ReadState {
|
||||||
// We are currently reading the length of the next frame of data.
|
/// We are currently reading the length of the next frame of data.
|
||||||
ReadingLength,
|
ReadLength { buf: [u8; MAX_LEN_BYTES as usize], pos: usize },
|
||||||
// We are currently reading the frame of data itself.
|
/// We are currently reading the frame of data itself.
|
||||||
ReadingData { frame_len: u16 },
|
ReadData { len: u16, pos: usize },
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<R, C> LengthDelimited<R, C>
|
impl Default for ReadState {
|
||||||
where
|
fn default() -> Self {
|
||||||
R: AsyncWrite,
|
ReadState::ReadLength {
|
||||||
C: Encoder
|
buf: [0; MAX_LEN_BYTES as usize],
|
||||||
{
|
pos: 0
|
||||||
pub fn new(inner: R, codec: C) -> LengthDelimited<R, C> {
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<R> LengthDelimited<R> {
|
||||||
|
/// Creates a new I/O resource for reading and writing unsigned-varint
|
||||||
|
/// length delimited frames.
|
||||||
|
pub fn new(inner: R) -> LengthDelimited<R> {
|
||||||
LengthDelimited {
|
LengthDelimited {
|
||||||
inner: FramedWrite::new(inner, codec),
|
inner,
|
||||||
internal_buffer: {
|
read_state: ReadState::default(),
|
||||||
let mut v = SmallVec::new();
|
read_buffer: BytesMut::with_capacity(DEFAULT_BUFFER_SIZE),
|
||||||
v.push(0);
|
write_buffer: BytesMut::with_capacity(DEFAULT_BUFFER_SIZE + MAX_LEN_BYTES as usize),
|
||||||
v
|
|
||||||
},
|
|
||||||
internal_buffer_pos: 0,
|
|
||||||
state: State::ReadingLength
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Destroys the `LengthDelimited` and returns the underlying socket.
|
/// Destroys the `LengthDelimited` and returns the underlying socket.
|
||||||
///
|
///
|
||||||
/// Contrary to its equivalent `tokio_io::codec::length_delimited::FramedRead`, this method is
|
/// This method is guaranteed not to skip any data from the socket.
|
||||||
/// guaranteed not to skip any data from the socket.
|
|
||||||
///
|
///
|
||||||
/// # Panic
|
/// # Panic
|
||||||
///
|
///
|
||||||
/// Will panic if called while there is data inside the buffer. **This can only happen if
|
/// Will panic if called while there is data inside the read or write buffer.
|
||||||
/// you call `poll()` manually**. Using this struct as it is intended to be used (i.e. through
|
/// **This can only happen if you call `poll()` manually**. Using this struct
|
||||||
/// the modifiers provided by the `futures` crate) will always leave the object in a state in
|
/// as it is intended to be used (i.e. through the high-level `futures` API)
|
||||||
/// which `into_inner()` will not panic.
|
/// will always leave the object in a state in which `into_inner()` will not panic.
|
||||||
#[inline]
|
|
||||||
pub fn into_inner(self) -> R {
|
pub fn into_inner(self) -> R {
|
||||||
assert_eq!(self.state, State::ReadingLength);
|
assert!(self.write_buffer.is_empty());
|
||||||
assert_eq!(self.internal_buffer_pos, 0);
|
assert!(self.read_buffer.is_empty());
|
||||||
self.inner.into_inner()
|
self.inner
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<R, C> Stream for LengthDelimited<R, C>
|
impl<R> Stream for LengthDelimited<R>
|
||||||
where
|
where
|
||||||
R: AsyncRead
|
R: AsyncRead
|
||||||
{
|
{
|
||||||
@ -98,16 +101,11 @@ where
|
|||||||
|
|
||||||
fn poll(&mut self) -> Poll<Option<Self::Item>, Self::Error> {
|
fn poll(&mut self) -> Poll<Option<Self::Item>, Self::Error> {
|
||||||
loop {
|
loop {
|
||||||
debug_assert!(!self.internal_buffer.is_empty());
|
match &mut self.read_state {
|
||||||
debug_assert!(self.internal_buffer_pos < self.internal_buffer.len());
|
ReadState::ReadLength { buf, pos } => {
|
||||||
|
match self.inner.read(&mut buf[*pos .. *pos + 1]) {
|
||||||
match self.state {
|
|
||||||
State::ReadingLength => {
|
|
||||||
let slice = &mut self.internal_buffer[self.internal_buffer_pos..];
|
|
||||||
match self.inner.get_mut().read(slice) {
|
|
||||||
Ok(0) => {
|
Ok(0) => {
|
||||||
// EOF
|
if *pos == 0 {
|
||||||
if self.internal_buffer_pos == 0 {
|
|
||||||
return Ok(Async::Ready(None));
|
return Ok(Async::Ready(None));
|
||||||
} else {
|
} else {
|
||||||
return Err(io::ErrorKind::UnexpectedEof.into());
|
return Err(io::ErrorKind::UnexpectedEof.into());
|
||||||
@ -115,7 +113,7 @@ where
|
|||||||
}
|
}
|
||||||
Ok(n) => {
|
Ok(n) => {
|
||||||
debug_assert_eq!(n, 1);
|
debug_assert_eq!(n, 1);
|
||||||
self.internal_buffer_pos += n;
|
*pos += n;
|
||||||
}
|
}
|
||||||
Err(ref err) if err.kind() == io::ErrorKind::WouldBlock => {
|
Err(ref err) if err.kind() == io::ErrorKind::WouldBlock => {
|
||||||
return Ok(Async::NotReady);
|
return Ok(Async::NotReady);
|
||||||
@ -125,56 +123,45 @@ where
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
debug_assert_eq!(self.internal_buffer.len(), self.internal_buffer_pos);
|
if (buf[*pos - 1] & 0x80) == 0 {
|
||||||
|
// MSB is not set, indicating the end of the length prefix.
|
||||||
if (*self.internal_buffer.last().unwrap_or(&0) & 0x80) == 0 {
|
let (len, _) = uvi::decode::u16(buf).map_err(|e| {
|
||||||
// End of length prefix. Most of the time we will switch to reading data,
|
|
||||||
// but we need to handle a few corner cases first.
|
|
||||||
let (frame_len, _) = decode::u16(&self.internal_buffer).map_err(|e| {
|
|
||||||
log::debug!("invalid length prefix: {}", e);
|
log::debug!("invalid length prefix: {}", e);
|
||||||
io::Error::new(io::ErrorKind::InvalidData, "invalid length prefix")
|
io::Error::new(io::ErrorKind::InvalidData, "invalid length prefix")
|
||||||
})?;
|
})?;
|
||||||
|
|
||||||
if frame_len >= 1 {
|
if len >= 1 {
|
||||||
self.state = State::ReadingData { frame_len };
|
self.read_state = ReadState::ReadData { len, pos: 0 };
|
||||||
self.internal_buffer.clear();
|
self.read_buffer.resize(len as usize, 0);
|
||||||
self.internal_buffer.reserve(frame_len as usize);
|
|
||||||
self.internal_buffer.extend((0..frame_len).map(|_| 0));
|
|
||||||
self.internal_buffer_pos = 0;
|
|
||||||
} else {
|
} else {
|
||||||
debug_assert_eq!(frame_len, 0);
|
debug_assert_eq!(len, 0);
|
||||||
self.state = State::ReadingLength;
|
self.read_state = ReadState::default();
|
||||||
self.internal_buffer.clear();
|
return Ok(Async::Ready(Some(Bytes::new())));
|
||||||
self.internal_buffer.push(0);
|
|
||||||
self.internal_buffer_pos = 0;
|
|
||||||
return Ok(Async::Ready(Some(From::from(&[][..]))));
|
|
||||||
}
|
}
|
||||||
} else if self.internal_buffer_pos >= 2 {
|
} else if *pos == MAX_LEN_BYTES as usize {
|
||||||
// Length prefix is too long. See module doc for info about max frame len.
|
// MSB signals more length bytes but we have already read the maximum.
|
||||||
return Err(io::Error::new(io::ErrorKind::InvalidData, "frame length too long"));
|
// See the module documentation about the max frame len.
|
||||||
} else {
|
return Err(io::Error::new(
|
||||||
// Prepare for next read.
|
io::ErrorKind::InvalidData,
|
||||||
self.internal_buffer.push(0);
|
"Maximum frame length exceeded"));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
State::ReadingData { frame_len } => {
|
ReadState::ReadData { len, pos } => {
|
||||||
let slice = &mut self.internal_buffer[self.internal_buffer_pos..];
|
match self.inner.read(&mut self.read_buffer[*pos..]) {
|
||||||
match self.inner.get_mut().read(slice) {
|
|
||||||
Ok(0) => return Err(io::ErrorKind::UnexpectedEof.into()),
|
Ok(0) => return Err(io::ErrorKind::UnexpectedEof.into()),
|
||||||
Ok(n) => self.internal_buffer_pos += n,
|
Ok(n) => *pos += n,
|
||||||
Err(ref err) if err.kind() == io::ErrorKind::WouldBlock => {
|
Err(err) =>
|
||||||
return Ok(Async::NotReady)
|
if err.kind() == io::ErrorKind::WouldBlock {
|
||||||
}
|
return Ok(Async::NotReady)
|
||||||
Err(err) => return Err(err)
|
} else {
|
||||||
|
return Err(err)
|
||||||
|
}
|
||||||
};
|
};
|
||||||
if self.internal_buffer_pos >= frame_len as usize {
|
if *pos == *len as usize {
|
||||||
// Finished reading the frame of data.
|
// Finished reading the frame.
|
||||||
self.state = State::ReadingLength;
|
let frame = self.read_buffer.split_off(0).freeze();
|
||||||
let out_data = From::from(&self.internal_buffer[..]);
|
self.read_state = ReadState::default();
|
||||||
self.internal_buffer.clear();
|
return Ok(Async::Ready(Some(frame)));
|
||||||
self.internal_buffer.push(0);
|
|
||||||
self.internal_buffer_pos = 0;
|
|
||||||
return Ok(Async::Ready(Some(out_data)));
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -182,27 +169,60 @@ where
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<R, C> Sink for LengthDelimited<R, C>
|
impl<R> Sink for LengthDelimited<R>
|
||||||
where
|
where
|
||||||
R: AsyncWrite,
|
R: AsyncWrite,
|
||||||
C: Encoder
|
|
||||||
{
|
{
|
||||||
type SinkItem = <FramedWrite<R, C> as Sink>::SinkItem;
|
type SinkItem = Bytes;
|
||||||
type SinkError = <FramedWrite<R, C> as Sink>::SinkError;
|
type SinkError = io::Error;
|
||||||
|
|
||||||
#[inline]
|
fn start_send(&mut self, msg: Self::SinkItem) -> StartSend<Self::SinkItem, Self::SinkError> {
|
||||||
fn start_send(&mut self, item: Self::SinkItem) -> StartSend<Self::SinkItem, Self::SinkError> {
|
// Use the maximum frame length also as a (soft) upper limit
|
||||||
self.inner.start_send(item)
|
// for the entire write buffer. The actual (hard) limit is thus
|
||||||
|
// implied to be roughly 2 * MAX_FRAME_SIZE.
|
||||||
|
if self.write_buffer.len() >= MAX_FRAME_SIZE as usize {
|
||||||
|
self.poll_complete()?;
|
||||||
|
if self.write_buffer.len() >= MAX_FRAME_SIZE as usize {
|
||||||
|
return Ok(AsyncSink::NotReady(msg))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
let len = msg.len() as u16;
|
||||||
|
if len > MAX_FRAME_SIZE {
|
||||||
|
return Err(io::Error::new(
|
||||||
|
io::ErrorKind::InvalidData,
|
||||||
|
"Maximum frame size exceeded."))
|
||||||
|
}
|
||||||
|
|
||||||
|
let mut uvi_buf = uvi::encode::u16_buffer();
|
||||||
|
let uvi_len = uvi::encode::u16(len, &mut uvi_buf);
|
||||||
|
self.write_buffer.reserve(len as usize + uvi_len.len());
|
||||||
|
self.write_buffer.put(uvi_len);
|
||||||
|
self.write_buffer.put(msg);
|
||||||
|
|
||||||
|
Ok(AsyncSink::Ready)
|
||||||
}
|
}
|
||||||
|
|
||||||
#[inline]
|
|
||||||
fn poll_complete(&mut self) -> Poll<(), Self::SinkError> {
|
fn poll_complete(&mut self) -> Poll<(), Self::SinkError> {
|
||||||
self.inner.poll_complete()
|
while !self.write_buffer.is_empty() {
|
||||||
|
let n = try_ready!(self.inner.poll_write(&self.write_buffer));
|
||||||
|
|
||||||
|
if n == 0 {
|
||||||
|
return Err(io::Error::new(
|
||||||
|
io::ErrorKind::WriteZero,
|
||||||
|
"Failed to write buffered frame."))
|
||||||
|
}
|
||||||
|
|
||||||
|
let _ = self.write_buffer.split_to(n);
|
||||||
|
}
|
||||||
|
|
||||||
|
try_ready!(self.inner.poll_flush());
|
||||||
|
return Ok(Async::Ready(()));
|
||||||
}
|
}
|
||||||
|
|
||||||
#[inline]
|
|
||||||
fn close(&mut self) -> Poll<(), Self::SinkError> {
|
fn close(&mut self) -> Poll<(), Self::SinkError> {
|
||||||
self.inner.close()
|
try_ready!(self.poll_complete());
|
||||||
|
Ok(self.inner.shutdown()?)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -211,12 +231,11 @@ mod tests {
|
|||||||
use futures::{Future, Stream};
|
use futures::{Future, Stream};
|
||||||
use crate::length_delimited::LengthDelimited;
|
use crate::length_delimited::LengthDelimited;
|
||||||
use std::io::{Cursor, ErrorKind};
|
use std::io::{Cursor, ErrorKind};
|
||||||
use unsigned_varint::codec::UviBytes;
|
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn basic_read() {
|
fn basic_read() {
|
||||||
let data = vec![6, 9, 8, 7, 6, 5, 4];
|
let data = vec![6, 9, 8, 7, 6, 5, 4];
|
||||||
let framed = LengthDelimited::new(Cursor::new(data), UviBytes::<Vec<_>>::default());
|
let framed = LengthDelimited::new(Cursor::new(data));
|
||||||
let recved = framed.collect().wait().unwrap();
|
let recved = framed.collect().wait().unwrap();
|
||||||
assert_eq!(recved, vec![vec![9, 8, 7, 6, 5, 4]]);
|
assert_eq!(recved, vec![vec![9, 8, 7, 6, 5, 4]]);
|
||||||
}
|
}
|
||||||
@ -224,7 +243,7 @@ mod tests {
|
|||||||
#[test]
|
#[test]
|
||||||
fn basic_read_two() {
|
fn basic_read_two() {
|
||||||
let data = vec![6, 9, 8, 7, 6, 5, 4, 3, 9, 8, 7];
|
let data = vec![6, 9, 8, 7, 6, 5, 4, 3, 9, 8, 7];
|
||||||
let framed = LengthDelimited::new(Cursor::new(data), UviBytes::<Vec<_>>::default());
|
let framed = LengthDelimited::new(Cursor::new(data));
|
||||||
let recved = framed.collect().wait().unwrap();
|
let recved = framed.collect().wait().unwrap();
|
||||||
assert_eq!(recved, vec![vec![9, 8, 7, 6, 5, 4], vec![9, 8, 7]]);
|
assert_eq!(recved, vec![vec![9, 8, 7, 6, 5, 4], vec![9, 8, 7]]);
|
||||||
}
|
}
|
||||||
@ -236,7 +255,7 @@ mod tests {
|
|||||||
let frame = (0..len).map(|n| (n & 0xff) as u8).collect::<Vec<_>>();
|
let frame = (0..len).map(|n| (n & 0xff) as u8).collect::<Vec<_>>();
|
||||||
let mut data = vec![(len & 0x7f) as u8 | 0x80, (len >> 7) as u8];
|
let mut data = vec![(len & 0x7f) as u8 | 0x80, (len >> 7) as u8];
|
||||||
data.extend(frame.clone().into_iter());
|
data.extend(frame.clone().into_iter());
|
||||||
let framed = LengthDelimited::new(Cursor::new(data), UviBytes::<Vec<_>>::default());
|
let framed = LengthDelimited::new(Cursor::new(data));
|
||||||
let recved = framed
|
let recved = framed
|
||||||
.into_future()
|
.into_future()
|
||||||
.map(|(m, _)| m)
|
.map(|(m, _)| m)
|
||||||
@ -250,7 +269,7 @@ mod tests {
|
|||||||
fn packet_len_too_long() {
|
fn packet_len_too_long() {
|
||||||
let mut data = vec![0x81, 0x81, 0x1];
|
let mut data = vec![0x81, 0x81, 0x1];
|
||||||
data.extend((0..16513).map(|_| 0));
|
data.extend((0..16513).map(|_| 0));
|
||||||
let framed = LengthDelimited::new(Cursor::new(data), UviBytes::<Vec<_>>::default());
|
let framed = LengthDelimited::new(Cursor::new(data));
|
||||||
let recved = framed
|
let recved = framed
|
||||||
.into_future()
|
.into_future()
|
||||||
.map(|(m, _)| m)
|
.map(|(m, _)| m)
|
||||||
@ -267,7 +286,7 @@ mod tests {
|
|||||||
#[test]
|
#[test]
|
||||||
fn empty_frames() {
|
fn empty_frames() {
|
||||||
let data = vec![0, 0, 6, 9, 8, 7, 6, 5, 4, 0, 3, 9, 8, 7];
|
let data = vec![0, 0, 6, 9, 8, 7, 6, 5, 4, 0, 3, 9, 8, 7];
|
||||||
let framed = LengthDelimited::new(Cursor::new(data), UviBytes::<Vec<_>>::default());
|
let framed = LengthDelimited::new(Cursor::new(data));
|
||||||
let recved = framed.collect().wait().unwrap();
|
let recved = framed.collect().wait().unwrap();
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
recved,
|
recved,
|
||||||
@ -284,7 +303,7 @@ mod tests {
|
|||||||
#[test]
|
#[test]
|
||||||
fn unexpected_eof_in_len() {
|
fn unexpected_eof_in_len() {
|
||||||
let data = vec![0x89];
|
let data = vec![0x89];
|
||||||
let framed = LengthDelimited::new(Cursor::new(data), UviBytes::<Vec<_>>::default());
|
let framed = LengthDelimited::new(Cursor::new(data));
|
||||||
let recved = framed.collect().wait();
|
let recved = framed.collect().wait();
|
||||||
if let Err(io_err) = recved {
|
if let Err(io_err) = recved {
|
||||||
assert_eq!(io_err.kind(), ErrorKind::UnexpectedEof)
|
assert_eq!(io_err.kind(), ErrorKind::UnexpectedEof)
|
||||||
@ -296,7 +315,7 @@ mod tests {
|
|||||||
#[test]
|
#[test]
|
||||||
fn unexpected_eof_in_data() {
|
fn unexpected_eof_in_data() {
|
||||||
let data = vec![5];
|
let data = vec![5];
|
||||||
let framed = LengthDelimited::new(Cursor::new(data), UviBytes::<Vec<_>>::default());
|
let framed = LengthDelimited::new(Cursor::new(data));
|
||||||
let recved = framed.collect().wait();
|
let recved = framed.collect().wait();
|
||||||
if let Err(io_err) = recved {
|
if let Err(io_err) = recved {
|
||||||
assert_eq!(io_err.kind(), ErrorKind::UnexpectedEof)
|
assert_eq!(io_err.kind(), ErrorKind::UnexpectedEof)
|
||||||
@ -308,7 +327,7 @@ mod tests {
|
|||||||
#[test]
|
#[test]
|
||||||
fn unexpected_eof_in_data2() {
|
fn unexpected_eof_in_data2() {
|
||||||
let data = vec![5, 9, 8, 7];
|
let data = vec![5, 9, 8, 7];
|
||||||
let framed = LengthDelimited::new(Cursor::new(data), UviBytes::<Vec<_>>::default());
|
let framed = LengthDelimited::new(Cursor::new(data));
|
||||||
let recved = framed.collect().wait();
|
let recved = framed.collect().wait();
|
||||||
if let Err(io_err) = recved {
|
if let Err(io_err) = recved {
|
||||||
assert_eq!(io_err.kind(), ErrorKind::UnexpectedEof)
|
assert_eq!(io_err.kind(), ErrorKind::UnexpectedEof)
|
||||||
|
@ -21,7 +21,7 @@
|
|||||||
//! # Multistream-select
|
//! # Multistream-select
|
||||||
//!
|
//!
|
||||||
//! This crate implements the `multistream-select` protocol, which is the protocol used by libp2p
|
//! This crate implements the `multistream-select` protocol, which is the protocol used by libp2p
|
||||||
//! to negotiate which protocol to use with the remote.
|
//! to negotiate which protocol to use with the remote on a connection or substream.
|
||||||
//!
|
//!
|
||||||
//! > **Note**: This crate is used by the internals of *libp2p*, and it is not required to
|
//! > **Note**: This crate is used by the internals of *libp2p*, and it is not required to
|
||||||
//! > understand it in order to use *libp2p*.
|
//! > understand it in order to use *libp2p*.
|
||||||
@ -76,6 +76,7 @@ mod protocol;
|
|||||||
|
|
||||||
use futures::prelude::*;
|
use futures::prelude::*;
|
||||||
use std::io;
|
use std::io;
|
||||||
|
use tokio_io::{AsyncRead, AsyncWrite};
|
||||||
|
|
||||||
pub use self::dialer_select::{dialer_select_proto, DialerSelectFuture};
|
pub use self::dialer_select::{dialer_select_proto, DialerSelectFuture};
|
||||||
pub use self::error::ProtocolChoiceError;
|
pub use self::error::ProtocolChoiceError;
|
||||||
@ -93,9 +94,9 @@ where
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<TInner> tokio_io::AsyncRead for Negotiated<TInner>
|
impl<TInner> AsyncRead for Negotiated<TInner>
|
||||||
where
|
where
|
||||||
TInner: tokio_io::AsyncRead
|
TInner: AsyncRead
|
||||||
{
|
{
|
||||||
unsafe fn prepare_uninitialized_buffer(&self, buf: &mut [u8]) -> bool {
|
unsafe fn prepare_uninitialized_buffer(&self, buf: &mut [u8]) -> bool {
|
||||||
self.0.prepare_uninitialized_buffer(buf)
|
self.0.prepare_uninitialized_buffer(buf)
|
||||||
@ -119,9 +120,9 @@ where
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<TInner> tokio_io::AsyncWrite for Negotiated<TInner>
|
impl<TInner> AsyncWrite for Negotiated<TInner>
|
||||||
where
|
where
|
||||||
TInner: tokio_io::AsyncWrite
|
TInner: AsyncWrite
|
||||||
{
|
{
|
||||||
fn shutdown(&mut self) -> Poll<(), io::Error> {
|
fn shutdown(&mut self) -> Poll<(), io::Error> {
|
||||||
self.0.shutdown()
|
self.0.shutdown()
|
||||||
|
@ -23,10 +23,10 @@
|
|||||||
|
|
||||||
use futures::{prelude::*, sink, stream::StreamFuture};
|
use futures::{prelude::*, sink, stream::StreamFuture};
|
||||||
use crate::protocol::{
|
use crate::protocol::{
|
||||||
DialerToListenerMessage,
|
Request,
|
||||||
|
Response,
|
||||||
Listener,
|
Listener,
|
||||||
ListenerFuture,
|
ListenerFuture,
|
||||||
ListenerToDialerMessage
|
|
||||||
};
|
};
|
||||||
use log::{debug, trace};
|
use log::{debug, trace};
|
||||||
use std::mem;
|
use std::mem;
|
||||||
@ -126,13 +126,13 @@ where
|
|||||||
Err((e, _)) => return Err(ProtocolChoiceError::from(e))
|
Err((e, _)) => return Err(ProtocolChoiceError::from(e))
|
||||||
};
|
};
|
||||||
match msg {
|
match msg {
|
||||||
Some(DialerToListenerMessage::ProtocolsListRequest) => {
|
Some(Request::ListProtocols) => {
|
||||||
trace!("protocols list response: {:?}", protocols
|
trace!("protocols list response: {:?}", protocols
|
||||||
.into_iter()
|
.into_iter()
|
||||||
.map(|p| p.as_ref().into())
|
.map(|p| p.as_ref().into())
|
||||||
.collect::<Vec<Vec<u8>>>());
|
.collect::<Vec<Vec<u8>>>());
|
||||||
let list = protocols.into_iter().collect();
|
let supported = protocols.into_iter().collect();
|
||||||
let msg = ListenerToDialerMessage::ProtocolsListResponse { list };
|
let msg = Response::SupportedProtocols { protocols: supported };
|
||||||
let sender = listener.send(msg);
|
let sender = listener.send(msg);
|
||||||
self.inner = ListenerSelectState::Outgoing {
|
self.inner = ListenerSelectState::Outgoing {
|
||||||
sender,
|
sender,
|
||||||
@ -140,12 +140,12 @@ where
|
|||||||
outcome: None
|
outcome: None
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
Some(DialerToListenerMessage::ProtocolRequest { name }) => {
|
Some(Request::Protocol { name }) => {
|
||||||
let mut outcome = None;
|
let mut outcome = None;
|
||||||
let mut send_back = ListenerToDialerMessage::NotAvailable;
|
let mut send_back = Response::ProtocolNotAvailable;
|
||||||
for supported in &protocols {
|
for supported in &protocols {
|
||||||
if name.as_ref() == supported.as_ref() {
|
if name.as_ref() == supported.as_ref() {
|
||||||
send_back = ListenerToDialerMessage::ProtocolAck {
|
send_back = Response::Protocol {
|
||||||
name: supported.clone()
|
name: supported.clone()
|
||||||
};
|
};
|
||||||
outcome = Some(supported);
|
outcome = Some(supported);
|
||||||
|
@ -20,23 +20,25 @@
|
|||||||
|
|
||||||
//! Contains the `Dialer` wrapper, which allows raw communications with a listener.
|
//! Contains the `Dialer` wrapper, which allows raw communications with a listener.
|
||||||
|
|
||||||
use bytes::{BufMut, Bytes, BytesMut};
|
use super::*;
|
||||||
|
|
||||||
|
use bytes::{Bytes, BytesMut};
|
||||||
use crate::length_delimited::LengthDelimited;
|
use crate::length_delimited::LengthDelimited;
|
||||||
use crate::protocol::DialerToListenerMessage;
|
use crate::protocol::{Request, Response, MultistreamSelectError};
|
||||||
use crate::protocol::ListenerToDialerMessage;
|
|
||||||
use crate::protocol::MultistreamSelectError;
|
|
||||||
use crate::protocol::MULTISTREAM_PROTOCOL_WITH_LF;
|
|
||||||
use futures::{prelude::*, sink, Async, StartSend, try_ready};
|
use futures::{prelude::*, sink, Async, StartSend, try_ready};
|
||||||
use std::io;
|
|
||||||
use tokio_codec::Encoder;
|
|
||||||
use tokio_io::{AsyncRead, AsyncWrite};
|
use tokio_io::{AsyncRead, AsyncWrite};
|
||||||
use unsigned_varint::{decode, codec::Uvi};
|
use std::marker;
|
||||||
|
use unsigned_varint as uvi;
|
||||||
|
|
||||||
|
/// The maximum number of supported protocols that can be processed.
|
||||||
|
const MAX_PROTOCOLS: usize = 1000;
|
||||||
|
|
||||||
/// Wraps around a `AsyncRead+AsyncWrite`.
|
/// Wraps around a `AsyncRead+AsyncWrite`.
|
||||||
/// Assumes that we're on the dialer's side. Produces and accepts messages.
|
/// Assumes that we're on the dialer's side. Produces and accepts messages.
|
||||||
pub struct Dialer<R, N> {
|
pub struct Dialer<R, N> {
|
||||||
inner: LengthDelimited<R, MessageEncoder<N>>,
|
inner: LengthDelimited<R>,
|
||||||
handshake_finished: bool
|
handshake_finished: bool,
|
||||||
|
_protocol_name: marker::PhantomData<N>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<R, N> Dialer<R, N>
|
impl<R, N> Dialer<R, N>
|
||||||
@ -45,15 +47,16 @@ where
|
|||||||
N: AsRef<[u8]>
|
N: AsRef<[u8]>
|
||||||
{
|
{
|
||||||
pub fn dial(inner: R) -> DialerFuture<R, N> {
|
pub fn dial(inner: R) -> DialerFuture<R, N> {
|
||||||
let codec = MessageEncoder(std::marker::PhantomData);
|
let io = LengthDelimited::new(inner);
|
||||||
let sender = LengthDelimited::new(inner, codec);
|
let mut buf = BytesMut::new();
|
||||||
|
Header::Multistream10.encode(&mut buf);
|
||||||
DialerFuture {
|
DialerFuture {
|
||||||
inner: sender.send(Message::Header)
|
inner: io.send(buf.freeze()),
|
||||||
|
_protocol_name: marker::PhantomData,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Grants back the socket. Typically used after a `ProtocolAck` has been received.
|
/// Grants back the socket. Typically used after a `ProtocolAck` has been received.
|
||||||
#[inline]
|
|
||||||
pub fn into_inner(self) -> R {
|
pub fn into_inner(self) -> R {
|
||||||
self.inner.into_inner()
|
self.inner.into_inner()
|
||||||
}
|
}
|
||||||
@ -64,24 +67,22 @@ where
|
|||||||
R: AsyncRead + AsyncWrite,
|
R: AsyncRead + AsyncWrite,
|
||||||
N: AsRef<[u8]>
|
N: AsRef<[u8]>
|
||||||
{
|
{
|
||||||
type SinkItem = DialerToListenerMessage<N>;
|
type SinkItem = Request<N>;
|
||||||
type SinkError = MultistreamSelectError;
|
type SinkError = MultistreamSelectError;
|
||||||
|
|
||||||
#[inline]
|
fn start_send(&mut self, request: Self::SinkItem) -> StartSend<Self::SinkItem, Self::SinkError> {
|
||||||
fn start_send(&mut self, item: Self::SinkItem) -> StartSend<Self::SinkItem, Self::SinkError> {
|
let mut msg = BytesMut::new();
|
||||||
match self.inner.start_send(Message::Body(item))? {
|
request.encode(&mut msg)?;
|
||||||
AsyncSink::NotReady(Message::Body(item)) => Ok(AsyncSink::NotReady(item)),
|
match self.inner.start_send(msg.freeze())? {
|
||||||
AsyncSink::NotReady(Message::Header) => unreachable!(),
|
AsyncSink::NotReady(_) => Ok(AsyncSink::NotReady(request)),
|
||||||
AsyncSink::Ready => Ok(AsyncSink::Ready)
|
AsyncSink::Ready => Ok(AsyncSink::Ready),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[inline]
|
|
||||||
fn poll_complete(&mut self) -> Poll<(), Self::SinkError> {
|
fn poll_complete(&mut self) -> Poll<(), Self::SinkError> {
|
||||||
Ok(self.inner.poll_complete()?)
|
Ok(self.inner.poll_complete()?)
|
||||||
}
|
}
|
||||||
|
|
||||||
#[inline]
|
|
||||||
fn close(&mut self) -> Poll<(), Self::SinkError> {
|
fn close(&mut self) -> Poll<(), Self::SinkError> {
|
||||||
Ok(self.inner.close()?)
|
Ok(self.inner.close()?)
|
||||||
}
|
}
|
||||||
@ -91,20 +92,20 @@ impl<R, N> Stream for Dialer<R, N>
|
|||||||
where
|
where
|
||||||
R: AsyncRead + AsyncWrite
|
R: AsyncRead + AsyncWrite
|
||||||
{
|
{
|
||||||
type Item = ListenerToDialerMessage<Bytes>;
|
type Item = Response<Bytes>;
|
||||||
type Error = MultistreamSelectError;
|
type Error = MultistreamSelectError;
|
||||||
|
|
||||||
fn poll(&mut self) -> Poll<Option<Self::Item>, Self::Error> {
|
fn poll(&mut self) -> Poll<Option<Self::Item>, Self::Error> {
|
||||||
loop {
|
loop {
|
||||||
let mut frame = match self.inner.poll() {
|
let mut msg = match self.inner.poll() {
|
||||||
Ok(Async::Ready(Some(frame))) => frame,
|
Ok(Async::Ready(Some(msg))) => msg,
|
||||||
Ok(Async::Ready(None)) => return Ok(Async::Ready(None)),
|
Ok(Async::Ready(None)) => return Ok(Async::Ready(None)),
|
||||||
Ok(Async::NotReady) => return Ok(Async::NotReady),
|
Ok(Async::NotReady) => return Ok(Async::NotReady),
|
||||||
Err(err) => return Err(err.into()),
|
Err(err) => return Err(err.into()),
|
||||||
};
|
};
|
||||||
|
|
||||||
if !self.handshake_finished {
|
if !self.handshake_finished {
|
||||||
if frame == MULTISTREAM_PROTOCOL_WITH_LF {
|
if msg == MSG_MULTISTREAM_1_0 {
|
||||||
self.handshake_finished = true;
|
self.handshake_finished = true;
|
||||||
continue;
|
continue;
|
||||||
} else {
|
} else {
|
||||||
@ -112,31 +113,31 @@ where
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if frame.get(0) == Some(&b'/') && frame.last() == Some(&b'\n') {
|
if msg.get(0) == Some(&b'/') && msg.last() == Some(&b'\n') {
|
||||||
let frame_len = frame.len();
|
let len = msg.len();
|
||||||
let protocol = frame.split_to(frame_len - 1);
|
let name = msg.split_to(len - 1);
|
||||||
return Ok(Async::Ready(Some(ListenerToDialerMessage::ProtocolAck {
|
return Ok(Async::Ready(Some(
|
||||||
name: protocol
|
Response::Protocol { name }
|
||||||
})));
|
)));
|
||||||
} else if frame == b"na\n"[..] {
|
} else if msg == MSG_PROTOCOL_NA {
|
||||||
return Ok(Async::Ready(Some(ListenerToDialerMessage::NotAvailable)));
|
return Ok(Async::Ready(Some(Response::ProtocolNotAvailable)));
|
||||||
} else {
|
} else {
|
||||||
// A varint number of protocols
|
// A varint number of protocols
|
||||||
let (num_protocols, mut remaining) = decode::usize(&frame)?;
|
let (num_protocols, mut remaining) = uvi::decode::usize(&msg)?;
|
||||||
if num_protocols > 1000 { // TODO: configurable limit
|
if num_protocols > MAX_PROTOCOLS { // TODO: configurable limit
|
||||||
return Err(MultistreamSelectError::VarintParseError("too many protocols".into()))
|
return Err(MultistreamSelectError::TooManyProtocols)
|
||||||
}
|
}
|
||||||
let mut out = Vec::with_capacity(num_protocols);
|
let mut protocols = Vec::with_capacity(num_protocols);
|
||||||
for _ in 0 .. num_protocols {
|
for _ in 0 .. num_protocols {
|
||||||
let (len, rem) = decode::usize(remaining)?;
|
let (len, rem) = uvi::decode::usize(remaining)?;
|
||||||
if len == 0 || len > rem.len() || rem[len - 1] != b'\n' {
|
if len == 0 || len > rem.len() || rem[len - 1] != b'\n' {
|
||||||
return Err(MultistreamSelectError::UnknownMessage)
|
return Err(MultistreamSelectError::UnknownMessage)
|
||||||
}
|
}
|
||||||
out.push(Bytes::from(&rem[.. len - 1]));
|
protocols.push(Bytes::from(&rem[.. len - 1]));
|
||||||
remaining = &rem[len ..]
|
remaining = &rem[len ..]
|
||||||
}
|
}
|
||||||
return Ok(Async::Ready(Some(
|
return Ok(Async::Ready(Some(
|
||||||
ListenerToDialerMessage::ProtocolsListResponse { list: out },
|
Response::SupportedProtocols { protocols },
|
||||||
)));
|
)));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -145,7 +146,8 @@ where
|
|||||||
|
|
||||||
/// Future, returned by `Dialer::new`, which send the handshake and returns the actual `Dialer`.
|
/// Future, returned by `Dialer::new`, which send the handshake and returns the actual `Dialer`.
|
||||||
pub struct DialerFuture<T: AsyncWrite, N: AsRef<[u8]>> {
|
pub struct DialerFuture<T: AsyncWrite, N: AsRef<[u8]>> {
|
||||||
inner: sink::Send<LengthDelimited<T, MessageEncoder<N>>>
|
inner: sink::Send<LengthDelimited<T>>,
|
||||||
|
_protocol_name: marker::PhantomData<N>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<T: AsyncWrite, N: AsRef<[u8]>> Future for DialerFuture<T, N> {
|
impl<T: AsyncWrite, N: AsRef<[u8]>> Future for DialerFuture<T, N> {
|
||||||
@ -154,57 +156,17 @@ impl<T: AsyncWrite, N: AsRef<[u8]>> Future for DialerFuture<T, N> {
|
|||||||
|
|
||||||
fn poll(&mut self) -> Poll<Self::Item, Self::Error> {
|
fn poll(&mut self) -> Poll<Self::Item, Self::Error> {
|
||||||
let inner = try_ready!(self.inner.poll());
|
let inner = try_ready!(self.inner.poll());
|
||||||
Ok(Async::Ready(Dialer { inner, handshake_finished: false }))
|
Ok(Async::Ready(Dialer {
|
||||||
}
|
inner,
|
||||||
}
|
handshake_finished: false,
|
||||||
|
_protocol_name: marker::PhantomData,
|
||||||
/// tokio-codec `Encoder` handling `DialerToListenerMessage` values.
|
}))
|
||||||
struct MessageEncoder<N>(std::marker::PhantomData<N>);
|
|
||||||
|
|
||||||
enum Message<N> {
|
|
||||||
Header,
|
|
||||||
Body(DialerToListenerMessage<N>)
|
|
||||||
}
|
|
||||||
|
|
||||||
impl<N: AsRef<[u8]>> Encoder for MessageEncoder<N> {
|
|
||||||
type Item = Message<N>;
|
|
||||||
type Error = MultistreamSelectError;
|
|
||||||
|
|
||||||
fn encode(&mut self, item: Self::Item, dest: &mut BytesMut) -> Result<(), Self::Error> {
|
|
||||||
match item {
|
|
||||||
Message::Header => {
|
|
||||||
Uvi::<usize>::default().encode(MULTISTREAM_PROTOCOL_WITH_LF.len(), dest)?;
|
|
||||||
dest.reserve(MULTISTREAM_PROTOCOL_WITH_LF.len());
|
|
||||||
dest.put(MULTISTREAM_PROTOCOL_WITH_LF);
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
Message::Body(DialerToListenerMessage::ProtocolRequest { name }) => {
|
|
||||||
if !name.as_ref().starts_with(b"/") {
|
|
||||||
return Err(MultistreamSelectError::WrongProtocolName)
|
|
||||||
}
|
|
||||||
let len = name.as_ref().len() + 1; // + 1 for \n
|
|
||||||
if len > std::u16::MAX as usize {
|
|
||||||
return Err(io::Error::new(io::ErrorKind::InvalidData, "name too long").into())
|
|
||||||
}
|
|
||||||
Uvi::<usize>::default().encode(len, dest)?;
|
|
||||||
dest.reserve(len);
|
|
||||||
dest.put(name.as_ref());
|
|
||||||
dest.put(&b"\n"[..]);
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
Message::Body(DialerToListenerMessage::ProtocolsListRequest) => {
|
|
||||||
Uvi::<usize>::default().encode(3, dest)?;
|
|
||||||
dest.reserve(3);
|
|
||||||
dest.put(&b"ls\n"[..]);
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod tests {
|
mod tests {
|
||||||
use crate::protocol::{Dialer, DialerToListenerMessage, MultistreamSelectError};
|
use super::*;
|
||||||
use tokio::runtime::current_thread::Runtime;
|
use tokio::runtime::current_thread::Runtime;
|
||||||
use tokio_tcp::{TcpListener, TcpStream};
|
use tokio_tcp::{TcpListener, TcpStream};
|
||||||
use futures::Future;
|
use futures::Future;
|
||||||
@ -225,13 +187,13 @@ mod tests {
|
|||||||
.from_err()
|
.from_err()
|
||||||
.and_then(move |stream| Dialer::dial(stream))
|
.and_then(move |stream| Dialer::dial(stream))
|
||||||
.and_then(move |dialer| {
|
.and_then(move |dialer| {
|
||||||
let p = b"invalid_name";
|
let name = b"invalid_name";
|
||||||
dialer.send(DialerToListenerMessage::ProtocolRequest { name: p })
|
dialer.send(Request::Protocol { name })
|
||||||
});
|
});
|
||||||
|
|
||||||
let mut rt = Runtime::new().unwrap();
|
let mut rt = Runtime::new().unwrap();
|
||||||
match rt.block_on(server.join(client)) {
|
match rt.block_on(server.join(client)) {
|
||||||
Err(MultistreamSelectError::WrongProtocolName) => (),
|
Err(MultistreamSelectError::InvalidProtocolName) => (),
|
||||||
_ => panic!(),
|
_ => panic!(),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -20,7 +20,7 @@
|
|||||||
|
|
||||||
//! Contains the error structs for the low-level protocol handling.
|
//! Contains the error structs for the low-level protocol handling.
|
||||||
|
|
||||||
use std::error;
|
use std::error::Error;
|
||||||
use std::fmt;
|
use std::fmt;
|
||||||
use std::io;
|
use std::io;
|
||||||
use unsigned_varint::decode;
|
use unsigned_varint::decode;
|
||||||
@ -38,29 +38,25 @@ pub enum MultistreamSelectError {
|
|||||||
UnknownMessage,
|
UnknownMessage,
|
||||||
|
|
||||||
/// Protocol names must always start with `/`, otherwise this error is returned.
|
/// Protocol names must always start with `/`, otherwise this error is returned.
|
||||||
WrongProtocolName,
|
InvalidProtocolName,
|
||||||
|
|
||||||
/// Failure to parse variable-length integer.
|
/// Too many protocols have been returned by the remote.
|
||||||
// TODO: we don't include the actual error, because that would remove Send from the enum
|
TooManyProtocols,
|
||||||
VarintParseError(String),
|
|
||||||
}
|
}
|
||||||
|
|
||||||
impl From<io::Error> for MultistreamSelectError {
|
impl From<io::Error> for MultistreamSelectError {
|
||||||
#[inline]
|
|
||||||
fn from(err: io::Error) -> MultistreamSelectError {
|
fn from(err: io::Error) -> MultistreamSelectError {
|
||||||
MultistreamSelectError::IoError(err)
|
MultistreamSelectError::IoError(err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl From<decode::Error> for MultistreamSelectError {
|
impl From<decode::Error> for MultistreamSelectError {
|
||||||
#[inline]
|
|
||||||
fn from(err: decode::Error) -> MultistreamSelectError {
|
fn from(err: decode::Error) -> MultistreamSelectError {
|
||||||
MultistreamSelectError::VarintParseError(err.to_string())
|
Self::from(io::Error::new(io::ErrorKind::InvalidData, err.to_string()))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl error::Error for MultistreamSelectError {
|
impl Error for MultistreamSelectError {
|
||||||
#[inline]
|
|
||||||
fn description(&self) -> &str {
|
fn description(&self) -> &str {
|
||||||
match *self {
|
match *self {
|
||||||
MultistreamSelectError::IoError(_) => "I/O error",
|
MultistreamSelectError::IoError(_) => "I/O error",
|
||||||
@ -68,16 +64,15 @@ impl error::Error for MultistreamSelectError {
|
|||||||
"the remote doesn't use the same multistream-select protocol as we do"
|
"the remote doesn't use the same multistream-select protocol as we do"
|
||||||
}
|
}
|
||||||
MultistreamSelectError::UnknownMessage => "received an unknown message from the remote",
|
MultistreamSelectError::UnknownMessage => "received an unknown message from the remote",
|
||||||
MultistreamSelectError::WrongProtocolName => {
|
MultistreamSelectError::InvalidProtocolName => {
|
||||||
"protocol names must always start with `/`, otherwise this error is returned"
|
"protocol names must always start with `/`, otherwise this error is returned"
|
||||||
}
|
}
|
||||||
MultistreamSelectError::VarintParseError(_) => {
|
MultistreamSelectError::TooManyProtocols =>
|
||||||
"failure to parse variable-length integer"
|
"Too many protocols."
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn cause(&self) -> Option<&dyn error::Error> {
|
fn source(&self) -> Option<&(dyn Error + 'static)> {
|
||||||
match *self {
|
match *self {
|
||||||
MultistreamSelectError::IoError(ref err) => Some(err),
|
MultistreamSelectError::IoError(ref err) => Some(err),
|
||||||
_ => None,
|
_ => None,
|
||||||
@ -86,8 +81,7 @@ impl error::Error for MultistreamSelectError {
|
|||||||
}
|
}
|
||||||
|
|
||||||
impl fmt::Display for MultistreamSelectError {
|
impl fmt::Display for MultistreamSelectError {
|
||||||
#[inline]
|
|
||||||
fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> {
|
fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> {
|
||||||
write!(fmt, "{}", error::Error::description(self))
|
write!(fmt, "{}", Error::description(self))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -20,23 +20,21 @@
|
|||||||
|
|
||||||
//! Contains the `Listener` wrapper, which allows raw communications with a dialer.
|
//! Contains the `Listener` wrapper, which allows raw communications with a dialer.
|
||||||
|
|
||||||
use bytes::{BufMut, Bytes, BytesMut};
|
use super::*;
|
||||||
|
|
||||||
|
use bytes::{Bytes, BytesMut};
|
||||||
use crate::length_delimited::LengthDelimited;
|
use crate::length_delimited::LengthDelimited;
|
||||||
use crate::protocol::DialerToListenerMessage;
|
use crate::protocol::{Request, Response, MultistreamSelectError};
|
||||||
use crate::protocol::ListenerToDialerMessage;
|
|
||||||
use crate::protocol::MultistreamSelectError;
|
|
||||||
use crate::protocol::MULTISTREAM_PROTOCOL_WITH_LF;
|
|
||||||
use futures::{prelude::*, sink, stream::StreamFuture};
|
use futures::{prelude::*, sink, stream::StreamFuture};
|
||||||
use log::{debug, trace};
|
use log::{debug, trace};
|
||||||
use std::{io, mem};
|
use std::{marker, mem};
|
||||||
use tokio_codec::Encoder;
|
|
||||||
use tokio_io::{AsyncRead, AsyncWrite};
|
use tokio_io::{AsyncRead, AsyncWrite};
|
||||||
use unsigned_varint::{encode, codec::Uvi};
|
|
||||||
|
|
||||||
/// Wraps around a `AsyncRead+AsyncWrite`. Assumes that we're on the listener's side. Produces and
|
/// Wraps around a `AsyncRead+AsyncWrite`. Assumes that we're on the listener's side. Produces and
|
||||||
/// accepts messages.
|
/// accepts messages.
|
||||||
pub struct Listener<R, N> {
|
pub struct Listener<R, N> {
|
||||||
inner: LengthDelimited<R, MessageEncoder<N>>
|
inner: LengthDelimited<R>,
|
||||||
|
_protocol_name: marker::PhantomData<N>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<R, N> Listener<R, N>
|
impl<R, N> Listener<R, N>
|
||||||
@ -47,16 +45,15 @@ where
|
|||||||
/// Takes ownership of a socket and starts the handshake. If the handshake succeeds, the
|
/// Takes ownership of a socket and starts the handshake. If the handshake succeeds, the
|
||||||
/// future returns a `Listener`.
|
/// future returns a `Listener`.
|
||||||
pub fn listen(inner: R) -> ListenerFuture<R, N> {
|
pub fn listen(inner: R) -> ListenerFuture<R, N> {
|
||||||
let codec = MessageEncoder(std::marker::PhantomData);
|
let inner = LengthDelimited::new(inner);
|
||||||
let inner = LengthDelimited::new(inner, codec);
|
|
||||||
ListenerFuture {
|
ListenerFuture {
|
||||||
inner: ListenerFutureState::Await { inner: inner.into_future() }
|
inner: ListenerFutureState::Await { inner: inner.into_future() },
|
||||||
|
_protocol_name: marker::PhantomData,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Grants back the socket. Typically used after a `ProtocolRequest` has been received and a
|
/// Grants back the socket. Typically used after a `ProtocolRequest` has been received and a
|
||||||
/// `ProtocolAck` has been sent back.
|
/// `ProtocolAck` has been sent back.
|
||||||
#[inline]
|
|
||||||
pub fn into_inner(self) -> R {
|
pub fn into_inner(self) -> R {
|
||||||
self.inner.into_inner()
|
self.inner.into_inner()
|
||||||
}
|
}
|
||||||
@ -67,24 +64,22 @@ where
|
|||||||
R: AsyncRead + AsyncWrite,
|
R: AsyncRead + AsyncWrite,
|
||||||
N: AsRef<[u8]>
|
N: AsRef<[u8]>
|
||||||
{
|
{
|
||||||
type SinkItem = ListenerToDialerMessage<N>;
|
type SinkItem = Response<N>;
|
||||||
type SinkError = MultistreamSelectError;
|
type SinkError = MultistreamSelectError;
|
||||||
|
|
||||||
#[inline]
|
fn start_send(&mut self, response: Self::SinkItem) -> StartSend<Self::SinkItem, Self::SinkError> {
|
||||||
fn start_send(&mut self, item: Self::SinkItem) -> StartSend<Self::SinkItem, Self::SinkError> {
|
let mut msg = BytesMut::new();
|
||||||
match self.inner.start_send(Message::Body(item))? {
|
response.encode(&mut msg)?;
|
||||||
AsyncSink::NotReady(Message::Body(item)) => Ok(AsyncSink::NotReady(item)),
|
match self.inner.start_send(msg.freeze())? {
|
||||||
AsyncSink::NotReady(Message::Header) => unreachable!(),
|
AsyncSink::NotReady(_) => Ok(AsyncSink::NotReady(response)),
|
||||||
AsyncSink::Ready => Ok(AsyncSink::Ready)
|
AsyncSink::Ready => Ok(AsyncSink::Ready)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[inline]
|
|
||||||
fn poll_complete(&mut self) -> Poll<(), Self::SinkError> {
|
fn poll_complete(&mut self) -> Poll<(), Self::SinkError> {
|
||||||
Ok(self.inner.poll_complete()?)
|
Ok(self.inner.poll_complete()?)
|
||||||
}
|
}
|
||||||
|
|
||||||
#[inline]
|
|
||||||
fn close(&mut self) -> Poll<(), Self::SinkError> {
|
fn close(&mut self) -> Poll<(), Self::SinkError> {
|
||||||
Ok(self.inner.close()?)
|
Ok(self.inner.close()?)
|
||||||
}
|
}
|
||||||
@ -94,26 +89,26 @@ impl<R, N> Stream for Listener<R, N>
|
|||||||
where
|
where
|
||||||
R: AsyncRead + AsyncWrite,
|
R: AsyncRead + AsyncWrite,
|
||||||
{
|
{
|
||||||
type Item = DialerToListenerMessage<Bytes>;
|
type Item = Request<Bytes>;
|
||||||
type Error = MultistreamSelectError;
|
type Error = MultistreamSelectError;
|
||||||
|
|
||||||
fn poll(&mut self) -> Poll<Option<Self::Item>, Self::Error> {
|
fn poll(&mut self) -> Poll<Option<Self::Item>, Self::Error> {
|
||||||
let mut frame = match self.inner.poll() {
|
let mut msg = match self.inner.poll() {
|
||||||
Ok(Async::Ready(Some(frame))) => frame,
|
Ok(Async::Ready(Some(msg))) => msg,
|
||||||
Ok(Async::Ready(None)) => return Ok(Async::Ready(None)),
|
Ok(Async::Ready(None)) => return Ok(Async::Ready(None)),
|
||||||
Ok(Async::NotReady) => return Ok(Async::NotReady),
|
Ok(Async::NotReady) => return Ok(Async::NotReady),
|
||||||
Err(err) => return Err(err.into()),
|
Err(err) => return Err(err.into()),
|
||||||
};
|
};
|
||||||
|
|
||||||
if frame.get(0) == Some(&b'/') && frame.last() == Some(&b'\n') {
|
if msg.get(0) == Some(&b'/') && msg.last() == Some(&b'\n') {
|
||||||
let frame_len = frame.len();
|
let len = msg.len();
|
||||||
let protocol = frame.split_to(frame_len - 1);
|
let name = msg.split_to(len - 1);
|
||||||
Ok(Async::Ready(Some(
|
Ok(Async::Ready(Some(
|
||||||
DialerToListenerMessage::ProtocolRequest { name: protocol },
|
Request::Protocol { name },
|
||||||
)))
|
)))
|
||||||
} else if frame == b"ls\n"[..] {
|
} else if msg == MSG_LS {
|
||||||
Ok(Async::Ready(Some(
|
Ok(Async::Ready(Some(
|
||||||
DialerToListenerMessage::ProtocolsListRequest,
|
Request::ListProtocols,
|
||||||
)))
|
)))
|
||||||
} else {
|
} else {
|
||||||
Err(MultistreamSelectError::UnknownMessage)
|
Err(MultistreamSelectError::UnknownMessage)
|
||||||
@ -124,16 +119,17 @@ where
|
|||||||
|
|
||||||
/// Future, returned by `Listener::new` which performs the handshake and returns
|
/// Future, returned by `Listener::new` which performs the handshake and returns
|
||||||
/// the `Listener` if successful.
|
/// the `Listener` if successful.
|
||||||
pub struct ListenerFuture<T: AsyncRead + AsyncWrite, N: AsRef<[u8]>> {
|
pub struct ListenerFuture<T: AsyncRead + AsyncWrite, N> {
|
||||||
inner: ListenerFutureState<T, N>
|
inner: ListenerFutureState<T>,
|
||||||
|
_protocol_name: marker::PhantomData<N>,
|
||||||
}
|
}
|
||||||
|
|
||||||
enum ListenerFutureState<T: AsyncRead + AsyncWrite, N: AsRef<[u8]>> {
|
enum ListenerFutureState<T: AsyncRead + AsyncWrite> {
|
||||||
Await {
|
Await {
|
||||||
inner: StreamFuture<LengthDelimited<T, MessageEncoder<N>>>
|
inner: StreamFuture<LengthDelimited<T>>
|
||||||
},
|
},
|
||||||
Reply {
|
Reply {
|
||||||
sender: sink::Send<LengthDelimited<T, MessageEncoder<N>>>
|
sender: sink::Send<LengthDelimited<T>>
|
||||||
},
|
},
|
||||||
Undefined
|
Undefined
|
||||||
}
|
}
|
||||||
@ -155,12 +151,14 @@ impl<T: AsyncRead + AsyncWrite, N: AsRef<[u8]>> Future for ListenerFuture<T, N>
|
|||||||
}
|
}
|
||||||
Err((e, _)) => return Err(MultistreamSelectError::from(e))
|
Err((e, _)) => return Err(MultistreamSelectError::from(e))
|
||||||
};
|
};
|
||||||
if msg.as_ref().map(|b| &b[..]) != Some(MULTISTREAM_PROTOCOL_WITH_LF) {
|
if msg.as_ref().map(|b| &b[..]) != Some(MSG_MULTISTREAM_1_0) {
|
||||||
debug!("failed handshake; received: {:?}", msg);
|
debug!("Unexpected message: {:?}", msg);
|
||||||
return Err(MultistreamSelectError::FailedHandshake)
|
return Err(MultistreamSelectError::FailedHandshake)
|
||||||
}
|
}
|
||||||
trace!("sending back /multistream/<version> to finish the handshake");
|
trace!("sending back /multistream/<version> to finish the handshake");
|
||||||
let sender = socket.send(Message::Header);
|
let mut frame = BytesMut::new();
|
||||||
|
Header::Multistream10.encode(&mut frame);
|
||||||
|
let sender = socket.send(frame.freeze());
|
||||||
self.inner = ListenerFutureState::Reply { sender }
|
self.inner = ListenerFutureState::Reply { sender }
|
||||||
}
|
}
|
||||||
ListenerFutureState::Reply { mut sender } => {
|
ListenerFutureState::Reply { mut sender } => {
|
||||||
@ -171,70 +169,13 @@ impl<T: AsyncRead + AsyncWrite, N: AsRef<[u8]>> Future for ListenerFuture<T, N>
|
|||||||
return Ok(Async::NotReady)
|
return Ok(Async::NotReady)
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
return Ok(Async::Ready(Listener { inner: listener }))
|
return Ok(Async::Ready(Listener {
|
||||||
|
inner: listener,
|
||||||
|
_protocol_name: marker::PhantomData
|
||||||
|
}))
|
||||||
}
|
}
|
||||||
ListenerFutureState::Undefined => panic!("ListenerFutureState::poll called after completion")
|
ListenerFutureState::Undefined =>
|
||||||
}
|
panic!("ListenerFutureState::poll called after completion")
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// tokio-codec `Encoder` handling `ListenerToDialerMessage` values.
|
|
||||||
struct MessageEncoder<N>(std::marker::PhantomData<N>);
|
|
||||||
|
|
||||||
enum Message<N> {
|
|
||||||
Header,
|
|
||||||
Body(ListenerToDialerMessage<N>)
|
|
||||||
}
|
|
||||||
|
|
||||||
impl<N: AsRef<[u8]>> Encoder for MessageEncoder<N> {
|
|
||||||
type Item = Message<N>;
|
|
||||||
type Error = MultistreamSelectError;
|
|
||||||
|
|
||||||
fn encode(&mut self, item: Self::Item, dest: &mut BytesMut) -> Result<(), Self::Error> {
|
|
||||||
match item {
|
|
||||||
Message::Header => {
|
|
||||||
Uvi::<usize>::default().encode(MULTISTREAM_PROTOCOL_WITH_LF.len(), dest)?;
|
|
||||||
dest.reserve(MULTISTREAM_PROTOCOL_WITH_LF.len());
|
|
||||||
dest.put(MULTISTREAM_PROTOCOL_WITH_LF);
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
Message::Body(ListenerToDialerMessage::ProtocolAck { name }) => {
|
|
||||||
if !name.as_ref().starts_with(b"/") {
|
|
||||||
return Err(MultistreamSelectError::WrongProtocolName)
|
|
||||||
}
|
|
||||||
let len = name.as_ref().len() + 1; // + 1 for \n
|
|
||||||
if len > std::u16::MAX as usize {
|
|
||||||
return Err(io::Error::new(io::ErrorKind::InvalidData, "name too long").into())
|
|
||||||
}
|
|
||||||
Uvi::<usize>::default().encode(len, dest)?;
|
|
||||||
dest.reserve(len);
|
|
||||||
dest.put(name.as_ref());
|
|
||||||
dest.put(&b"\n"[..]);
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
Message::Body(ListenerToDialerMessage::ProtocolsListResponse { list }) => {
|
|
||||||
let mut buf = encode::usize_buffer();
|
|
||||||
let mut out_msg = Vec::from(encode::usize(list.len(), &mut buf));
|
|
||||||
for e in &list {
|
|
||||||
if e.as_ref().len() + 1 > std::u16::MAX as usize {
|
|
||||||
return Err(io::Error::new(io::ErrorKind::InvalidData, "name too long").into())
|
|
||||||
}
|
|
||||||
out_msg.extend(encode::usize(e.as_ref().len() + 1, &mut buf)); // +1 for '\n'
|
|
||||||
out_msg.extend_from_slice(e.as_ref());
|
|
||||||
out_msg.push(b'\n')
|
|
||||||
}
|
|
||||||
let len = encode::usize(out_msg.len(), &mut buf);
|
|
||||||
dest.reserve(len.len() + out_msg.len());
|
|
||||||
dest.put(len);
|
|
||||||
dest.put(out_msg);
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
Message::Body(ListenerToDialerMessage::NotAvailable) => {
|
|
||||||
Uvi::<usize>::default().encode(3, dest)?;
|
|
||||||
dest.reserve(3);
|
|
||||||
dest.put(&b"na\n"[..]);
|
|
||||||
Ok(())
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -242,12 +183,12 @@ impl<N: AsRef<[u8]>> Encoder for MessageEncoder<N> {
|
|||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod tests {
|
mod tests {
|
||||||
|
use super::*;
|
||||||
use tokio::runtime::current_thread::Runtime;
|
use tokio::runtime::current_thread::Runtime;
|
||||||
use tokio_tcp::{TcpListener, TcpStream};
|
use tokio_tcp::{TcpListener, TcpStream};
|
||||||
use bytes::Bytes;
|
use bytes::Bytes;
|
||||||
use futures::Future;
|
use futures::Future;
|
||||||
use futures::{Sink, Stream};
|
use futures::{Sink, Stream};
|
||||||
use crate::protocol::{Dialer, Listener, ListenerToDialerMessage, MultistreamSelectError};
|
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn wrong_proto_name() {
|
fn wrong_proto_name() {
|
||||||
@ -260,8 +201,8 @@ mod tests {
|
|||||||
.map_err(|(e, _)| e.into())
|
.map_err(|(e, _)| e.into())
|
||||||
.and_then(move |(connec, _)| Listener::listen(connec.unwrap()))
|
.and_then(move |(connec, _)| Listener::listen(connec.unwrap()))
|
||||||
.and_then(|listener| {
|
.and_then(|listener| {
|
||||||
let proto_name = Bytes::from("invalid-proto");
|
let name = Bytes::from("invalid-proto");
|
||||||
listener.send(ListenerToDialerMessage::ProtocolAck { name: proto_name })
|
listener.send(Response::Protocol { name })
|
||||||
});
|
});
|
||||||
|
|
||||||
let client = TcpStream::connect(&listener_addr)
|
let client = TcpStream::connect(&listener_addr)
|
||||||
@ -270,7 +211,7 @@ mod tests {
|
|||||||
|
|
||||||
let mut rt = Runtime::new().unwrap();
|
let mut rt = Runtime::new().unwrap();
|
||||||
match rt.block_on(server.join(client)) {
|
match rt.block_on(server.join(client)) {
|
||||||
Err(MultistreamSelectError::WrongProtocolName) => (),
|
Err(MultistreamSelectError::InvalidProtocolName) => (),
|
||||||
_ => panic!(),
|
_ => panic!(),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -20,47 +20,125 @@
|
|||||||
|
|
||||||
//! Contains lower-level structs to handle the multistream protocol.
|
//! Contains lower-level structs to handle the multistream protocol.
|
||||||
|
|
||||||
|
const MSG_MULTISTREAM_1_0: &[u8] = b"/multistream/1.0.0\n";
|
||||||
|
const MSG_PROTOCOL_NA: &[u8] = b"na\n";
|
||||||
|
const MSG_LS: &[u8] = b"ls\n";
|
||||||
|
|
||||||
mod dialer;
|
mod dialer;
|
||||||
mod error;
|
mod error;
|
||||||
mod listener;
|
mod listener;
|
||||||
|
|
||||||
const MULTISTREAM_PROTOCOL_WITH_LF: &[u8] = b"/multistream/1.0.0\n";
|
|
||||||
|
|
||||||
pub use self::dialer::{Dialer, DialerFuture};
|
pub use self::dialer::{Dialer, DialerFuture};
|
||||||
pub use self::error::MultistreamSelectError;
|
pub use self::error::MultistreamSelectError;
|
||||||
pub use self::listener::{Listener, ListenerFuture};
|
pub use self::listener::{Listener, ListenerFuture};
|
||||||
|
|
||||||
|
use bytes::{BytesMut, BufMut};
|
||||||
|
use unsigned_varint as uvi;
|
||||||
|
|
||||||
|
pub enum Header {
|
||||||
|
Multistream10
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Header {
|
||||||
|
fn encode(&self, dest: &mut BytesMut) {
|
||||||
|
match self {
|
||||||
|
Header::Multistream10 => {
|
||||||
|
dest.reserve(MSG_MULTISTREAM_1_0.len());
|
||||||
|
dest.put(MSG_MULTISTREAM_1_0);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
/// Message sent from the dialer to the listener.
|
/// Message sent from the dialer to the listener.
|
||||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||||
pub enum DialerToListenerMessage<N> {
|
pub enum Request<N> {
|
||||||
/// The dialer wants us to use a protocol.
|
/// The dialer wants us to use a protocol.
|
||||||
///
|
///
|
||||||
/// If this is accepted (by receiving back a `ProtocolAck`), then we immediately start
|
/// If this is accepted (by receiving back a `ProtocolAck`), then we immediately start
|
||||||
/// communicating in the new protocol.
|
/// communicating in the new protocol.
|
||||||
ProtocolRequest {
|
Protocol {
|
||||||
/// Name of the protocol.
|
/// Name of the protocol.
|
||||||
name: N
|
name: N
|
||||||
},
|
},
|
||||||
|
|
||||||
/// The dialer requested the list of protocols that the listener supports.
|
/// The dialer requested the list of protocols that the listener supports.
|
||||||
ProtocolsListRequest,
|
ListProtocols,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
impl<N: AsRef<[u8]>> Request<N> {
|
||||||
|
fn encode(&self, dest: &mut BytesMut) -> Result<(), MultistreamSelectError> {
|
||||||
|
match self {
|
||||||
|
Request::Protocol { name } => {
|
||||||
|
if !name.as_ref().starts_with(b"/") {
|
||||||
|
return Err(MultistreamSelectError::InvalidProtocolName)
|
||||||
|
}
|
||||||
|
let len = name.as_ref().len() + 1; // + 1 for \n
|
||||||
|
dest.reserve(len);
|
||||||
|
dest.put(name.as_ref());
|
||||||
|
dest.put(&b"\n"[..]);
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
Request::ListProtocols => {
|
||||||
|
dest.reserve(MSG_LS.len());
|
||||||
|
dest.put(MSG_LS);
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
/// Message sent from the listener to the dialer.
|
/// Message sent from the listener to the dialer.
|
||||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||||
pub enum ListenerToDialerMessage<N> {
|
pub enum Response<N> {
|
||||||
/// The protocol requested by the dialer is accepted. The socket immediately starts using the
|
/// The protocol requested by the dialer is accepted. The socket immediately starts using the
|
||||||
/// new protocol.
|
/// new protocol.
|
||||||
ProtocolAck { name: N },
|
Protocol { name: N },
|
||||||
|
|
||||||
/// The protocol requested by the dialer is not supported or available.
|
/// The protocol requested by the dialer is not supported or available.
|
||||||
NotAvailable,
|
ProtocolNotAvailable,
|
||||||
|
|
||||||
/// Response to the request for the list of protocols.
|
/// Response to the request for the list of protocols.
|
||||||
ProtocolsListResponse {
|
SupportedProtocols {
|
||||||
/// The list of protocols.
|
/// The list of protocols.
|
||||||
// TODO: use some sort of iterator
|
// TODO: use some sort of iterator
|
||||||
list: Vec<N>,
|
protocols: Vec<N>,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
impl<N: AsRef<[u8]>> Response<N> {
|
||||||
|
fn encode(&self, dest: &mut BytesMut) -> Result<(), MultistreamSelectError> {
|
||||||
|
match self {
|
||||||
|
Response::Protocol { name } => {
|
||||||
|
if !name.as_ref().starts_with(b"/") {
|
||||||
|
return Err(MultistreamSelectError::InvalidProtocolName)
|
||||||
|
}
|
||||||
|
let len = name.as_ref().len() + 1; // + 1 for \n
|
||||||
|
dest.reserve(len);
|
||||||
|
dest.put(name.as_ref());
|
||||||
|
dest.put(&b"\n"[..]);
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
Response::SupportedProtocols { protocols } => {
|
||||||
|
let mut buf = uvi::encode::usize_buffer();
|
||||||
|
let mut out_msg = Vec::from(uvi::encode::usize(protocols.len(), &mut buf));
|
||||||
|
for p in protocols {
|
||||||
|
out_msg.extend(uvi::encode::usize(p.as_ref().len() + 1, &mut buf)); // +1 for '\n'
|
||||||
|
out_msg.extend_from_slice(p.as_ref());
|
||||||
|
out_msg.push(b'\n')
|
||||||
|
}
|
||||||
|
dest.reserve(out_msg.len());
|
||||||
|
dest.put(out_msg);
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
Response::ProtocolNotAvailable => {
|
||||||
|
dest.reserve(MSG_PROTOCOL_NA.len());
|
||||||
|
dest.put(MSG_PROTOCOL_NA);
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@ -24,7 +24,7 @@
|
|||||||
|
|
||||||
use crate::ProtocolChoiceError;
|
use crate::ProtocolChoiceError;
|
||||||
use crate::dialer_select::{dialer_select_proto_parallel, dialer_select_proto_serial};
|
use crate::dialer_select::{dialer_select_proto_parallel, dialer_select_proto_serial};
|
||||||
use crate::protocol::{Dialer, DialerToListenerMessage, Listener, ListenerToDialerMessage};
|
use crate::protocol::{Dialer, Request, Listener, Response};
|
||||||
use crate::{dialer_select_proto, listener_select_proto};
|
use crate::{dialer_select_proto, listener_select_proto};
|
||||||
use futures::prelude::*;
|
use futures::prelude::*;
|
||||||
use tokio::runtime::current_thread::Runtime;
|
use tokio::runtime::current_thread::Runtime;
|
||||||
@ -56,23 +56,23 @@ fn negotiate_with_self_succeeds() {
|
|||||||
.and_then(|l| l.into_future().map_err(|(e, _)| e))
|
.and_then(|l| l.into_future().map_err(|(e, _)| e))
|
||||||
.and_then(|(msg, rest)| {
|
.and_then(|(msg, rest)| {
|
||||||
let proto = match msg {
|
let proto = match msg {
|
||||||
Some(DialerToListenerMessage::ProtocolRequest { name }) => name,
|
Some(Request::Protocol { name }) => name,
|
||||||
_ => panic!(),
|
_ => panic!(),
|
||||||
};
|
};
|
||||||
rest.send(ListenerToDialerMessage::ProtocolAck { name: proto })
|
rest.send(Response::Protocol { name: proto })
|
||||||
});
|
});
|
||||||
|
|
||||||
let client = TcpStream::connect(&listener_addr)
|
let client = TcpStream::connect(&listener_addr)
|
||||||
.from_err()
|
.from_err()
|
||||||
.and_then(move |stream| Dialer::dial(stream))
|
.and_then(move |stream| Dialer::dial(stream))
|
||||||
.and_then(move |dialer| {
|
.and_then(move |dialer| {
|
||||||
let p = b"/hello/1.0.0";
|
let name = b"/hello/1.0.0";
|
||||||
dialer.send(DialerToListenerMessage::ProtocolRequest { name: p })
|
dialer.send(Request::Protocol { name })
|
||||||
})
|
})
|
||||||
.and_then(move |dialer| dialer.into_future().map_err(|(e, _)| e))
|
.and_then(move |dialer| dialer.into_future().map_err(|(e, _)| e))
|
||||||
.and_then(move |(msg, _)| {
|
.and_then(move |(msg, _)| {
|
||||||
let proto = match msg {
|
let proto = match msg {
|
||||||
Some(ListenerToDialerMessage::ProtocolAck { name }) => name,
|
Some(Response::Protocol { name }) => name,
|
||||||
_ => panic!(),
|
_ => panic!(),
|
||||||
};
|
};
|
||||||
assert_eq!(proto, "/hello/1.0.0");
|
assert_eq!(proto, "/hello/1.0.0");
|
||||||
|
Loading…
x
Reference in New Issue
Block a user