multistream-select: Less allocations. (#800)

This commit is contained in:
Toralf Wittner
2019-01-09 15:09:35 +01:00
committed by GitHub
parent aedf9c0c31
commit f1959252b7
9 changed files with 467 additions and 372 deletions

View File

@ -224,6 +224,7 @@ type NameWrapIter<I> =
std::iter::Map<I, fn(<I as Iterator>::Item) -> NameWrap<<I as Iterator>::Item>>; std::iter::Map<I, fn(<I as Iterator>::Item) -> NameWrap<<I as Iterator>::Item>>;
/// Wrapper type to expose an `AsRef<[u8]>` impl for all types implementing `ProtocolName`. /// Wrapper type to expose an `AsRef<[u8]>` impl for all types implementing `ProtocolName`.
#[derive(Clone)]
struct NameWrap<N>(N); struct NameWrap<N>(N);
impl<N: ProtocolName> AsRef<[u8]> for NameWrap<N> { impl<N: ProtocolName> AsRef<[u8]> for NameWrap<N> {

View File

@ -91,7 +91,7 @@ impl<T: AsRef<[u8]>> ProtocolName for T {
/// or both. /// or both.
pub trait UpgradeInfo { pub trait UpgradeInfo {
/// Opaque type representing a negotiable protocol. /// Opaque type representing a negotiable protocol.
type Info: ProtocolName; type Info: ProtocolName + Clone;
/// Iterator returned by `protocol_info`. /// Iterator returned by `protocol_info`.
type InfoIter: IntoIterator<Item = Self::Info>; type InfoIter: IntoIterator<Item = Self::Info>;

View File

@ -21,9 +21,13 @@
//! Contains the `dialer_select_proto` code, which allows selecting a protocol thanks to //! Contains the `dialer_select_proto` code, which allows selecting a protocol thanks to
//! `multistream-select` for the dialer. //! `multistream-select` for the dialer.
use bytes::Bytes; use futures::{future::Either, prelude::*, stream::StreamFuture};
use futures::{future::Either, prelude::*, sink, stream::StreamFuture}; use crate::protocol::{
use crate::protocol::{Dialer, DialerFuture, DialerToListenerMessage, ListenerToDialerMessage}; Dialer,
DialerFuture,
DialerToListenerMessage,
ListenerToDialerMessage
};
use log::trace; use log::trace;
use std::mem; use std::mem;
use tokio_io::{AsyncRead, AsyncWrite}; use tokio_io::{AsyncRead, AsyncWrite};
@ -44,7 +48,6 @@ pub type DialerSelectFuture<R, I> = Either<DialerSelectSeq<R, I>, DialerSelectPa
/// remote, and the protocol name that we passed (so that you don't have to clone the name). On /// remote, and the protocol name that we passed (so that you don't have to clone the name). On
/// success, the function returns the identifier (of type `P`), plus the socket which now uses that /// success, the function returns the identifier (of type `P`), plus the socket which now uses that
/// chosen protocol. /// chosen protocol.
#[inline]
pub fn dialer_select_proto<R, I>(inner: R, protocols: I) -> DialerSelectFuture<R, I::IntoIter> pub fn dialer_select_proto<R, I>(inner: R, protocols: I) -> DialerSelectFuture<R, I::IntoIter>
where where
R: AsyncRead + AsyncWrite, R: AsyncRead + AsyncWrite,
@ -64,12 +67,13 @@ where
/// ///
/// Same as `dialer_select_proto`. Tries protocols one by one. The iterator doesn't need to produce /// Same as `dialer_select_proto`. Tries protocols one by one. The iterator doesn't need to produce
/// match functions, because it's not needed. /// match functions, because it's not needed.
pub fn dialer_select_proto_serial<R, I>(inner: R, protocols: I,) -> DialerSelectSeq<R, I> pub fn dialer_select_proto_serial<R, I>(inner: R, protocols: I) -> DialerSelectSeq<R, I::IntoIter>
where where
R: AsyncRead + AsyncWrite, R: AsyncRead + AsyncWrite,
I: Iterator, I: IntoIterator,
I::Item: AsRef<[u8]> I::Item: AsRef<[u8]>
{ {
let protocols = protocols.into_iter();
DialerSelectSeq { DialerSelectSeq {
inner: DialerSelectSeqState::AwaitDialer { dialer_fut: Dialer::new(inner), protocols } inner: DialerSelectSeqState::AwaitDialer { dialer_fut: Dialer::new(inner), protocols }
} }
@ -78,26 +82,37 @@ where
/// Future, returned by `dialer_select_proto_serial` which selects a protocol /// Future, returned by `dialer_select_proto_serial` which selects a protocol
/// and dialer sequentially. /// and dialer sequentially.
pub struct DialerSelectSeq<R: AsyncRead + AsyncWrite, I: Iterator> { pub struct DialerSelectSeq<R, I>
where
R: AsyncRead + AsyncWrite,
I: Iterator,
I::Item: AsRef<[u8]>
{
inner: DialerSelectSeqState<R, I> inner: DialerSelectSeqState<R, I>
} }
enum DialerSelectSeqState<R: AsyncRead + AsyncWrite, I: Iterator> { enum DialerSelectSeqState<R, I>
where
R: AsyncRead + AsyncWrite,
I: Iterator,
I::Item: AsRef<[u8]>
{
AwaitDialer { AwaitDialer {
dialer_fut: DialerFuture<R>, dialer_fut: DialerFuture<R, I::Item>,
protocols: I protocols: I
}, },
NextProtocol { NextProtocol {
dialer: Dialer<R>, dialer: Dialer<R, I::Item>,
proto_name: I::Item,
protocols: I protocols: I
}, },
SendProtocol { FlushProtocol {
sender: sink::Send<Dialer<R>>, dialer: Dialer<R, I::Item>,
proto_name: I::Item, proto_name: I::Item,
protocols: I protocols: I
}, },
AwaitProtocol { AwaitProtocol {
stream: StreamFuture<Dialer<R>>, stream: StreamFuture<Dialer<R, I::Item>>,
proto_name: I::Item, proto_name: I::Item,
protocols: I protocols: I
}, },
@ -106,9 +121,9 @@ enum DialerSelectSeqState<R: AsyncRead + AsyncWrite, I: Iterator> {
impl<R, I> Future for DialerSelectSeq<R, I> impl<R, I> Future for DialerSelectSeq<R, I>
where where
I: Iterator,
I::Item: AsRef<[u8]>,
R: AsyncRead + AsyncWrite, R: AsyncRead + AsyncWrite,
I: Iterator,
I::Item: AsRef<[u8]> + Clone
{ {
type Item = (I::Item, R); type Item = (I::Item, R);
type Error = ProtocolChoiceError; type Error = ProtocolChoiceError;
@ -116,7 +131,7 @@ where
fn poll(&mut self) -> Poll<Self::Item, Self::Error> { fn poll(&mut self) -> Poll<Self::Item, Self::Error> {
loop { loop {
match mem::replace(&mut self.inner, DialerSelectSeqState::Undefined) { match mem::replace(&mut self.inner, DialerSelectSeqState::Undefined) {
DialerSelectSeqState::AwaitDialer { mut dialer_fut, protocols } => { DialerSelectSeqState::AwaitDialer { mut dialer_fut, mut protocols } => {
let dialer = match dialer_fut.poll()? { let dialer = match dialer_fut.poll()? {
Async::Ready(d) => d, Async::Ready(d) => d,
Async::NotReady => { Async::NotReady => {
@ -124,42 +139,57 @@ where
return Ok(Async::NotReady) return Ok(Async::NotReady)
} }
}; };
self.inner = DialerSelectSeqState::NextProtocol { dialer, protocols } let proto_name = protocols.next().ok_or(ProtocolChoiceError::NoProtocolFound)?;
self.inner = DialerSelectSeqState::NextProtocol {
dialer,
protocols,
proto_name
} }
DialerSelectSeqState::NextProtocol { dialer, mut protocols } => { }
let proto_name = DialerSelectSeqState::NextProtocol { mut dialer, protocols, proto_name } => {
protocols.next().ok_or(ProtocolChoiceError::NoProtocolFound)?; trace!("sending {:?}", proto_name.as_ref());
let req = DialerToListenerMessage::ProtocolRequest { let req = DialerToListenerMessage::ProtocolRequest {
name: Bytes::from(proto_name.as_ref()) name: proto_name.clone()
}; };
trace!("sending {:?}", req); match dialer.start_send(req)? {
let sender = dialer.send(req); AsyncSink::Ready => {
self.inner = DialerSelectSeqState::SendProtocol { self.inner = DialerSelectSeqState::FlushProtocol {
sender, dialer,
proto_name, proto_name,
protocols protocols
} }
} }
DialerSelectSeqState::SendProtocol { mut sender, proto_name, protocols } => { AsyncSink::NotReady(_) => {
let dialer = match sender.poll()? { self.inner = DialerSelectSeqState::NextProtocol {
Async::Ready(d) => d, dialer,
Async::NotReady => { protocols,
self.inner = DialerSelectSeqState::SendProtocol { proto_name
sender,
proto_name,
protocols
}; };
return Ok(Async::NotReady) return Ok(Async::NotReady)
} }
}; }
}
DialerSelectSeqState::FlushProtocol { mut dialer, proto_name, protocols } => {
match dialer.poll_complete()? {
Async::Ready(()) => {
let stream = dialer.into_future(); let stream = dialer.into_future();
self.inner = DialerSelectSeqState::AwaitProtocol { self.inner = DialerSelectSeqState::AwaitProtocol {
stream, stream,
proto_name, proto_name,
protocols protocols
};
} }
DialerSelectSeqState::AwaitProtocol { mut stream, proto_name, protocols } => { }
Async::NotReady => {
self.inner = DialerSelectSeqState::FlushProtocol {
dialer,
proto_name,
protocols
};
return Ok(Async::NotReady)
}
}
}
DialerSelectSeqState::AwaitProtocol { mut stream, proto_name, mut protocols } => {
let (m, r) = match stream.poll() { let (m, r) = match stream.poll() {
Ok(Async::Ready(x)) => x, Ok(Async::Ready(x)) => x,
Ok(Async::NotReady) => { Ok(Async::NotReady) => {
@ -178,9 +208,15 @@ where
if name.as_ref() == proto_name.as_ref() => if name.as_ref() == proto_name.as_ref() =>
{ {
return Ok(Async::Ready((proto_name, r.into_inner()))) return Ok(Async::Ready((proto_name, r.into_inner())))
}, }
ListenerToDialerMessage::NotAvailable => { ListenerToDialerMessage::NotAvailable => {
self.inner = DialerSelectSeqState::NextProtocol { dialer: r, protocols } let proto_name = protocols.next()
.ok_or(ProtocolChoiceError::NoProtocolFound)?;
self.inner = DialerSelectSeqState::NextProtocol {
dialer: r,
protocols,
proto_name
}
} }
_ => return Err(ProtocolChoiceError::UnexpectedMessage) _ => return Err(ProtocolChoiceError::UnexpectedMessage)
} }
@ -192,17 +228,17 @@ where
} }
} }
/// Helps selecting a protocol amongst the ones supported. /// Helps selecting a protocol amongst the ones supported.
/// ///
/// Same as `dialer_select_proto`. Queries the list of supported protocols from the remote, then /// Same as `dialer_select_proto`. Queries the list of supported protocols from the remote, then
/// chooses the most appropriate one. /// chooses the most appropriate one.
pub fn dialer_select_proto_parallel<R, I>(inner: R, protocols: I) -> DialerSelectPar<R, I> pub fn dialer_select_proto_parallel<R, I>(inner: R, protocols: I) -> DialerSelectPar<R, I::IntoIter>
where where
I: Iterator, R: AsyncRead + AsyncWrite,
I::Item: AsRef<[u8]>, I: IntoIterator,
R: AsyncRead + AsyncWrite I::Item: AsRef<[u8]>
{ {
let protocols = protocols.into_iter();
DialerSelectPar { DialerSelectPar {
inner: DialerSelectParState::AwaitDialer { dialer_fut: Dialer::new(inner), protocols } inner: DialerSelectParState::AwaitDialer { dialer_fut: Dialer::new(inner), protocols }
} }
@ -212,29 +248,47 @@ where
/// Future, returned by `dialer_select_proto_parallel`, which selects a protocol and dialer in /// Future, returned by `dialer_select_proto_parallel`, which selects a protocol and dialer in
/// parellel, by first requesting the liste of protocols supported by the remote endpoint and /// parellel, by first requesting the liste of protocols supported by the remote endpoint and
/// then selecting the most appropriate one by applying a match predicate to the result. /// then selecting the most appropriate one by applying a match predicate to the result.
pub struct DialerSelectPar<R: AsyncRead + AsyncWrite, I: Iterator> { pub struct DialerSelectPar<R, I>
where
R: AsyncRead + AsyncWrite,
I: Iterator,
I::Item: AsRef<[u8]>
{
inner: DialerSelectParState<R, I> inner: DialerSelectParState<R, I>
} }
enum DialerSelectParState<R: AsyncRead + AsyncWrite, I: Iterator> { enum DialerSelectParState<R, I>
where
R: AsyncRead + AsyncWrite,
I: Iterator,
I::Item: AsRef<[u8]>
{
AwaitDialer { AwaitDialer {
dialer_fut: DialerFuture<R>, dialer_fut: DialerFuture<R, I::Item>,
protocols: I protocols: I
}, },
SendRequest { ProtocolList {
sender: sink::Send<Dialer<R>>, dialer: Dialer<R, I::Item>,
protocols: I protocols: I
}, },
AwaitResponse { FlushListRequest {
stream: StreamFuture<Dialer<R>>, dialer: Dialer<R, I::Item>,
protocols: I protocols: I
}, },
SendProtocol { AwaitListResponse {
sender: sink::Send<Dialer<R>>, stream: StreamFuture<Dialer<R, I::Item>>,
protocols: I,
},
Protocol {
dialer: Dialer<R, I::Item>,
proto_name: I::Item
},
FlushProtocol {
dialer: Dialer<R, I::Item>,
proto_name: I::Item proto_name: I::Item
}, },
AwaitProtocol { AwaitProtocol {
stream: StreamFuture<Dialer<R>>, stream: StreamFuture<Dialer<R, I::Item>>,
proto_name: I::Item proto_name: I::Item
}, },
Undefined Undefined
@ -242,9 +296,9 @@ enum DialerSelectParState<R: AsyncRead + AsyncWrite, I: Iterator> {
impl<R, I> Future for DialerSelectPar<R, I> impl<R, I> Future for DialerSelectPar<R, I>
where where
I: Iterator,
I::Item: AsRef<[u8]>,
R: AsyncRead + AsyncWrite, R: AsyncRead + AsyncWrite,
I: Iterator,
I::Item: AsRef<[u8]> + Clone
{ {
type Item = (I::Item, R); type Item = (I::Item, R);
type Error = ProtocolChoiceError; type Error = ProtocolChoiceError;
@ -253,41 +307,63 @@ where
loop { loop {
match mem::replace(&mut self.inner, DialerSelectParState::Undefined) { match mem::replace(&mut self.inner, DialerSelectParState::Undefined) {
DialerSelectParState::AwaitDialer { mut dialer_fut, protocols } => { DialerSelectParState::AwaitDialer { mut dialer_fut, protocols } => {
let dialer = match dialer_fut.poll()? { match dialer_fut.poll()? {
Async::Ready(d) => d, Async::Ready(dialer) => {
self.inner = DialerSelectParState::ProtocolList { dialer, protocols }
}
Async::NotReady => { Async::NotReady => {
self.inner = DialerSelectParState::AwaitDialer { dialer_fut, protocols }; self.inner = DialerSelectParState::AwaitDialer { dialer_fut, protocols };
return Ok(Async::NotReady) return Ok(Async::NotReady)
} }
};
trace!("requesting protocols list");
let sender = dialer.send(DialerToListenerMessage::ProtocolsListRequest);
self.inner = DialerSelectParState::SendRequest { sender, protocols };
} }
DialerSelectParState::SendRequest { mut sender, protocols } => { }
let dialer = match sender.poll()? { DialerSelectParState::ProtocolList { mut dialer, protocols } => {
Async::Ready(d) => d, trace!("requesting protocols list");
Async::NotReady => { match dialer.start_send(DialerToListenerMessage::ProtocolsListRequest)? {
self.inner = DialerSelectParState::SendRequest { sender, protocols }; AsyncSink::Ready => {
self.inner = DialerSelectParState::FlushListRequest {
dialer,
protocols
}
}
AsyncSink::NotReady(_) => {
self.inner = DialerSelectParState::ProtocolList { dialer, protocols };
return Ok(Async::NotReady) return Ok(Async::NotReady)
} }
};
let stream = dialer.into_future();
self.inner = DialerSelectParState::AwaitResponse { stream, protocols };
} }
DialerSelectParState::AwaitResponse { mut stream, protocols } => { }
let (m, d) = match stream.poll() { DialerSelectParState::FlushListRequest { mut dialer, protocols } => {
match dialer.poll_complete()? {
Async::Ready(()) => {
self.inner = DialerSelectParState::AwaitListResponse {
stream: dialer.into_future(),
protocols
}
}
Async::NotReady => {
self.inner = DialerSelectParState::FlushListRequest {
dialer,
protocols
};
return Ok(Async::NotReady)
}
}
}
DialerSelectParState::AwaitListResponse { mut stream, protocols } => {
let (resp, dialer) = match stream.poll() {
Ok(Async::Ready(x)) => x, Ok(Async::Ready(x)) => x,
Ok(Async::NotReady) => { Ok(Async::NotReady) => {
self.inner = DialerSelectParState::AwaitResponse { stream, protocols }; self.inner = DialerSelectParState::AwaitListResponse { stream, protocols };
return Ok(Async::NotReady) return Ok(Async::NotReady)
} }
Err((e, _)) => return Err(ProtocolChoiceError::from(e)) Err((e, _)) => return Err(ProtocolChoiceError::from(e))
}; };
trace!("protocols list response: {:?}", m); trace!("protocols list response: {:?}", resp);
let list = match m { let list =
Some(ListenerToDialerMessage::ProtocolsListResponse { list }) => list, if let Some(ListenerToDialerMessage::ProtocolsListResponse { list }) = resp {
_ => return Err(ProtocolChoiceError::UnexpectedMessage), list
} else {
return Err(ProtocolChoiceError::UnexpectedMessage)
}; };
let mut found = None; let mut found = None;
for local_name in protocols { for local_name in protocols {
@ -302,47 +378,52 @@ where
} }
} }
let proto_name = found.ok_or(ProtocolChoiceError::NoProtocolFound)?; let proto_name = found.ok_or(ProtocolChoiceError::NoProtocolFound)?;
trace!("sending {:?}", proto_name.as_ref()); self.inner = DialerSelectParState::Protocol { dialer, proto_name }
let sender = d.send(DialerToListenerMessage::ProtocolRequest {
name: Bytes::from(proto_name.as_ref())
});
self.inner = DialerSelectParState::SendProtocol { sender, proto_name };
} }
DialerSelectParState::SendProtocol { mut sender, proto_name } => { DialerSelectParState::Protocol { mut dialer, proto_name } => {
let dialer = match sender.poll()? { trace!("requesting protocol: {:?}", proto_name.as_ref());
Async::Ready(d) => d, let req = DialerToListenerMessage::ProtocolRequest {
Async::NotReady => { name: proto_name.clone()
self.inner = DialerSelectParState::SendProtocol {
sender,
proto_name
}; };
match dialer.start_send(req)? {
AsyncSink::Ready => {
self.inner = DialerSelectParState::FlushProtocol { dialer, proto_name }
}
AsyncSink::NotReady(_) => {
self.inner = DialerSelectParState::Protocol { dialer, proto_name };
return Ok(Async::NotReady) return Ok(Async::NotReady)
} }
}; }
let stream = dialer.into_future(); }
DialerSelectParState::FlushProtocol { mut dialer, proto_name } => {
match dialer.poll_complete()? {
Async::Ready(()) => {
self.inner = DialerSelectParState::AwaitProtocol { self.inner = DialerSelectParState::AwaitProtocol {
stream, stream: dialer.into_future(),
proto_name proto_name
}; }
}
Async::NotReady => {
self.inner = DialerSelectParState::FlushProtocol { dialer, proto_name };
return Ok(Async::NotReady)
}
}
} }
DialerSelectParState::AwaitProtocol { mut stream, proto_name } => { DialerSelectParState::AwaitProtocol { mut stream, proto_name } => {
let (m, r) = match stream.poll() { let (resp, dialer) = match stream.poll() {
Ok(Async::Ready(x)) => x, Ok(Async::Ready(x)) => x,
Ok(Async::NotReady) => { Ok(Async::NotReady) => {
self.inner = DialerSelectParState::AwaitProtocol { self.inner = DialerSelectParState::AwaitProtocol { stream, proto_name };
stream,
proto_name
};
return Ok(Async::NotReady) return Ok(Async::NotReady)
} }
Err((e, _)) => return Err(ProtocolChoiceError::from(e)) Err((e, _)) => return Err(ProtocolChoiceError::from(e))
}; };
trace!("received {:?}", m); trace!("received {:?}", resp);
match m { match resp {
Some(ListenerToDialerMessage::ProtocolAck { ref name }) Some(ListenerToDialerMessage::ProtocolAck { ref name })
if name.as_ref() == proto_name.as_ref() => if name.as_ref() == proto_name.as_ref() =>
{ {
return Ok(Async::Ready((proto_name, r.into_inner()))) return Ok(Async::Ready((proto_name, dialer.into_inner())))
} }
_ => return Err(ProtocolChoiceError::UnexpectedMessage) _ => return Err(ProtocolChoiceError::UnexpectedMessage)
} }
@ -354,4 +435,3 @@ where
} }
} }

View File

@ -18,21 +18,22 @@
// FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER // FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
// DEALINGS IN THE SOFTWARE. // DEALINGS IN THE SOFTWARE.
use bytes::Bytes;
use futures::{Async, Poll, Sink, StartSend, Stream}; use futures::{Async, Poll, Sink, StartSend, Stream};
use smallvec::SmallVec; use smallvec::SmallVec;
use std::{io::{Error as IoError, ErrorKind as IoErrorKind}, marker::PhantomData, u16}; use std::{io, u16};
use tokio_codec::FramedWrite; use tokio_codec::{Encoder, FramedWrite};
use tokio_io::{AsyncRead, AsyncWrite}; use tokio_io::{AsyncRead, AsyncWrite};
use unsigned_varint::codec::UviBytes; use unsigned_varint::decode;
/// `Stream` and `Sink` wrapping some `AsyncRead + AsyncWrite` object to read /// `Stream` and `Sink` wrapping some `AsyncRead + AsyncWrite` object to read
/// and write unsigned-varint prefixed frames. /// and write unsigned-varint prefixed frames.
/// ///
/// We purposely only support a frame length of under 64kiB. Frames mostly consist /// We purposely only support a frame length of under 64kiB. Frames mostly consist
/// in a short protocol name, which is highly unlikely to be more than 64kiB long. /// in a short protocol name, which is highly unlikely to be more than 64kiB long.
pub struct LengthDelimited<I, S> { pub struct LengthDelimited<R, C> {
// The inner socket where data is pulled from. // The inner socket where data is pulled from.
inner: FramedWrite<S, UviBytes>, inner: FramedWrite<R, C>,
// Intermediary buffer where we put either the length of the next frame of data, or the frame // Intermediary buffer where we put either the length of the next frame of data, or the frame
// of data itself before it is returned. // of data itself before it is returned.
// Must always contain enough space to read data from `inner`. // Must always contain enough space to read data from `inner`.
@ -40,8 +41,7 @@ pub struct LengthDelimited<I, S> {
// Number of bytes within `internal_buffer` that contain valid data. // Number of bytes within `internal_buffer` that contain valid data.
internal_buffer_pos: usize, internal_buffer_pos: usize,
// State of the decoder. // State of the decoder.
state: State, state: State
marker: PhantomData<I>,
} }
#[derive(Debug, Copy, Clone, PartialEq, Eq)] #[derive(Debug, Copy, Clone, PartialEq, Eq)]
@ -52,24 +52,21 @@ enum State {
ReadingData { frame_len: u16 }, ReadingData { frame_len: u16 },
} }
impl<I, S> LengthDelimited<I, S> impl<R, C> LengthDelimited<R, C>
where where
S: AsyncWrite R: AsyncWrite,
C: Encoder
{ {
pub fn new(inner: S) -> LengthDelimited<I, S> { pub fn new(inner: R, codec: C) -> LengthDelimited<R, C> {
let mut encoder = UviBytes::default();
encoder.set_max_len(usize::from(u16::MAX));
LengthDelimited { LengthDelimited {
inner: FramedWrite::new(inner, encoder), inner: FramedWrite::new(inner, codec),
internal_buffer: { internal_buffer: {
let mut v = SmallVec::new(); let mut v = SmallVec::new();
v.push(0); v.push(0);
v v
}, },
internal_buffer_pos: 0, internal_buffer_pos: 0,
state: State::ReadingLength, state: State::ReadingLength
marker: PhantomData,
} }
} }
@ -85,20 +82,19 @@ where
/// the modifiers provided by the `futures` crate) will always leave the object in a state in /// the modifiers provided by the `futures` crate) will always leave the object in a state in
/// which `into_inner()` will not panic. /// which `into_inner()` will not panic.
#[inline] #[inline]
pub fn into_inner(self) -> S { pub fn into_inner(self) -> R {
assert_eq!(self.state, State::ReadingLength); assert_eq!(self.state, State::ReadingLength);
assert_eq!(self.internal_buffer_pos, 0); assert_eq!(self.internal_buffer_pos, 0);
self.inner.into_inner() self.inner.into_inner()
} }
} }
impl<I, S> Stream for LengthDelimited<I, S> impl<R, C> Stream for LengthDelimited<R, C>
where where
S: AsyncRead, R: AsyncRead
I: for<'r> From<&'r [u8]>,
{ {
type Item = I; type Item = Bytes;
type Error = IoError; type Error = io::Error;
fn poll(&mut self) -> Poll<Option<Self::Item>, Self::Error> { fn poll(&mut self) -> Poll<Option<Self::Item>, Self::Error> {
loop { loop {
@ -107,23 +103,21 @@ where
match self.state { match self.state {
State::ReadingLength => { State::ReadingLength => {
match self.inner let slice = &mut self.internal_buffer[self.internal_buffer_pos..];
.get_mut() match self.inner.get_mut().read(slice) {
.read(&mut self.internal_buffer[self.internal_buffer_pos..])
{
Ok(0) => { Ok(0) => {
// EOF // EOF
if self.internal_buffer_pos == 0 { if self.internal_buffer_pos == 0 {
return Ok(Async::Ready(None)); return Ok(Async::Ready(None));
} else { } else {
return Err(IoError::new(IoErrorKind::BrokenPipe, "unexpected eof")); return Err(io::ErrorKind::UnexpectedEof.into());
} }
} }
Ok(n) => { Ok(n) => {
debug_assert_eq!(n, 1); debug_assert_eq!(n, 1);
self.internal_buffer_pos += n; self.internal_buffer_pos += n;
} }
Err(ref err) if err.kind() == IoErrorKind::WouldBlock => { Err(ref err) if err.kind() == io::ErrorKind::WouldBlock => {
return Ok(Async::NotReady); return Ok(Async::NotReady);
} }
Err(err) => { Err(err) => {
@ -136,7 +130,10 @@ where
if (*self.internal_buffer.last().unwrap_or(&0) & 0x80) == 0 { if (*self.internal_buffer.last().unwrap_or(&0) & 0x80) == 0 {
// End of length prefix. Most of the time we will switch to reading data, // End of length prefix. Most of the time we will switch to reading data,
// but we need to handle a few corner cases first. // but we need to handle a few corner cases first.
let frame_len = decode_length_prefix(&self.internal_buffer); let (frame_len, _) = decode::u16(&self.internal_buffer).map_err(|e| {
log::debug!("invalid length prefix: {}", e);
io::Error::new(io::ErrorKind::InvalidData, "invalid length prefix")
})?;
if frame_len >= 1 { if frame_len >= 1 {
self.state = State::ReadingData { frame_len }; self.state = State::ReadingData { frame_len };
@ -154,33 +151,22 @@ where
} }
} else if self.internal_buffer_pos >= 2 { } else if self.internal_buffer_pos >= 2 {
// Length prefix is too long. See module doc for info about max frame len. // Length prefix is too long. See module doc for info about max frame len.
return Err(IoError::new( return Err(io::Error::new(io::ErrorKind::InvalidData, "frame length too long"));
IoErrorKind::InvalidData,
"frame length too long",
));
} else { } else {
// Prepare for next read. // Prepare for next read.
self.internal_buffer.push(0); self.internal_buffer.push(0);
} }
} }
State::ReadingData { frame_len } => { State::ReadingData { frame_len } => {
match self.inner let slice = &mut self.internal_buffer[self.internal_buffer_pos..];
.get_mut() match self.inner.get_mut().read(slice) {
.read(&mut self.internal_buffer[self.internal_buffer_pos..]) Ok(0) => return Err(io::ErrorKind::UnexpectedEof.into()),
{
Ok(0) => {
return Err(IoError::new(IoErrorKind::BrokenPipe, "unexpected eof"));
}
Ok(n) => self.internal_buffer_pos += n, Ok(n) => self.internal_buffer_pos += n,
Err(ref err) if err.kind() == IoErrorKind::WouldBlock => { Err(ref err) if err.kind() == io::ErrorKind::WouldBlock => {
return Ok(Async::NotReady); return Ok(Async::NotReady)
}
Err(err) => {
return Err(err);
} }
Err(err) => return Err(err)
}; };
if self.internal_buffer_pos >= frame_len as usize { if self.internal_buffer_pos >= frame_len as usize {
// Finished reading the frame of data. // Finished reading the frame of data.
self.state = State::ReadingLength; self.state = State::ReadingLength;
@ -196,12 +182,13 @@ where
} }
} }
impl<I, S> Sink for LengthDelimited<I, S> impl<R, C> Sink for LengthDelimited<R, C>
where where
S: AsyncWrite R: AsyncWrite,
C: Encoder
{ {
type SinkItem = <FramedWrite<S, UviBytes> as Sink>::SinkItem; type SinkItem = <FramedWrite<R, C> as Sink>::SinkItem;
type SinkError = <FramedWrite<S, UviBytes> as Sink>::SinkError; type SinkError = <FramedWrite<R, C> as Sink>::SinkError;
#[inline] #[inline]
fn start_send(&mut self, item: Self::SinkItem) -> StartSend<Self::SinkItem, Self::SinkError> { fn start_send(&mut self, item: Self::SinkItem) -> StartSend<Self::SinkItem, Self::SinkError> {
@ -219,33 +206,17 @@ where
} }
} }
fn decode_length_prefix(buf: &[u8]) -> u16 {
debug_assert!(buf.len() <= 2);
let mut sum = 0u16;
for &byte in buf.iter().rev() {
let byte = byte & 0x7f;
sum <<= 7;
debug_assert!(sum.checked_add(u16::from(byte)).is_some());
sum += u16::from(byte);
}
sum
}
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use futures::{Future, Stream}; use futures::{Future, Stream};
use crate::length_delimited::LengthDelimited; use crate::length_delimited::LengthDelimited;
use std::io::Cursor; use std::io::{Cursor, ErrorKind};
use std::io::ErrorKind; use unsigned_varint::codec::UviBytes;
#[test] #[test]
fn basic_read() { fn basic_read() {
let data = vec![6, 9, 8, 7, 6, 5, 4]; let data = vec![6, 9, 8, 7, 6, 5, 4];
let framed = LengthDelimited::<Vec<u8>, _>::new(Cursor::new(data)); let framed = LengthDelimited::new(Cursor::new(data), UviBytes::<Vec<_>>::default());
let recved = framed.collect().wait().unwrap(); let recved = framed.collect().wait().unwrap();
assert_eq!(recved, vec![vec![9, 8, 7, 6, 5, 4]]); assert_eq!(recved, vec![vec![9, 8, 7, 6, 5, 4]]);
} }
@ -253,8 +224,7 @@ mod tests {
#[test] #[test]
fn basic_read_two() { fn basic_read_two() {
let data = vec![6, 9, 8, 7, 6, 5, 4, 3, 9, 8, 7]; let data = vec![6, 9, 8, 7, 6, 5, 4, 3, 9, 8, 7];
let framed = LengthDelimited::<Vec<u8>, _>::new(Cursor::new(data)); let framed = LengthDelimited::new(Cursor::new(data), UviBytes::<Vec<_>>::default());
let recved = framed.collect().wait().unwrap(); let recved = framed.collect().wait().unwrap();
assert_eq!(recved, vec![vec![9, 8, 7, 6, 5, 4], vec![9, 8, 7]]); assert_eq!(recved, vec![vec![9, 8, 7, 6, 5, 4], vec![9, 8, 7]]);
} }
@ -266,8 +236,7 @@ mod tests {
let frame = (0..len).map(|n| (n & 0xff) as u8).collect::<Vec<_>>(); let frame = (0..len).map(|n| (n & 0xff) as u8).collect::<Vec<_>>();
let mut data = vec![(len & 0x7f) as u8 | 0x80, (len >> 7) as u8]; let mut data = vec![(len & 0x7f) as u8 | 0x80, (len >> 7) as u8];
data.extend(frame.clone().into_iter()); data.extend(frame.clone().into_iter());
let framed = LengthDelimited::<Vec<u8>, _>::new(Cursor::new(data)); let framed = LengthDelimited::new(Cursor::new(data), UviBytes::<Vec<_>>::default());
let recved = framed let recved = framed
.into_future() .into_future()
.map(|(m, _)| m) .map(|(m, _)| m)
@ -281,24 +250,24 @@ mod tests {
fn packet_len_too_long() { fn packet_len_too_long() {
let mut data = vec![0x81, 0x81, 0x1]; let mut data = vec![0x81, 0x81, 0x1];
data.extend((0..16513).map(|_| 0)); data.extend((0..16513).map(|_| 0));
let framed = LengthDelimited::<Vec<u8>, _>::new(Cursor::new(data)); let framed = LengthDelimited::new(Cursor::new(data), UviBytes::<Vec<_>>::default());
let recved = framed let recved = framed
.into_future() .into_future()
.map(|(m, _)| m) .map(|(m, _)| m)
.map_err(|(err, _)| err) .map_err(|(err, _)| err)
.wait(); .wait();
match recved {
Err(io_err) => assert_eq!(io_err.kind(), ErrorKind::InvalidData), if let Err(io_err) = recved {
_ => panic!(), assert_eq!(io_err.kind(), ErrorKind::InvalidData)
} else {
panic!()
} }
} }
#[test] #[test]
fn empty_frames() { fn empty_frames() {
let data = vec![0, 0, 6, 9, 8, 7, 6, 5, 4, 0, 3, 9, 8, 7]; let data = vec![0, 0, 6, 9, 8, 7, 6, 5, 4, 0, 3, 9, 8, 7];
let framed = LengthDelimited::<Vec<u8>, _>::new(Cursor::new(data)); let framed = LengthDelimited::new(Cursor::new(data), UviBytes::<Vec<_>>::default());
let recved = framed.collect().wait().unwrap(); let recved = framed.collect().wait().unwrap();
assert_eq!( assert_eq!(
recved, recved,
@ -315,36 +284,36 @@ mod tests {
#[test] #[test]
fn unexpected_eof_in_len() { fn unexpected_eof_in_len() {
let data = vec![0x89]; let data = vec![0x89];
let framed = LengthDelimited::<Vec<u8>, _>::new(Cursor::new(data)); let framed = LengthDelimited::new(Cursor::new(data), UviBytes::<Vec<_>>::default());
let recved = framed.collect().wait(); let recved = framed.collect().wait();
match recved { if let Err(io_err) = recved {
Err(io_err) => assert_eq!(io_err.kind(), ErrorKind::BrokenPipe), assert_eq!(io_err.kind(), ErrorKind::UnexpectedEof)
_ => panic!(), } else {
panic!()
} }
} }
#[test] #[test]
fn unexpected_eof_in_data() { fn unexpected_eof_in_data() {
let data = vec![5]; let data = vec![5];
let framed = LengthDelimited::<Vec<u8>, _>::new(Cursor::new(data)); let framed = LengthDelimited::new(Cursor::new(data), UviBytes::<Vec<_>>::default());
let recved = framed.collect().wait(); let recved = framed.collect().wait();
match recved { if let Err(io_err) = recved {
Err(io_err) => assert_eq!(io_err.kind(), ErrorKind::BrokenPipe), assert_eq!(io_err.kind(), ErrorKind::UnexpectedEof)
_ => panic!(), } else {
panic!()
} }
} }
#[test] #[test]
fn unexpected_eof_in_data2() { fn unexpected_eof_in_data2() {
let data = vec![5, 9, 8, 7]; let data = vec![5, 9, 8, 7];
let framed = LengthDelimited::<Vec<u8>, _>::new(Cursor::new(data)); let framed = LengthDelimited::new(Cursor::new(data), UviBytes::<Vec<_>>::default());
let recved = framed.collect().wait(); let recved = framed.collect().wait();
match recved { if let Err(io_err) = recved {
Err(io_err) => assert_eq!(io_err.kind(), ErrorKind::BrokenPipe), assert_eq!(io_err.kind(), ErrorKind::UnexpectedEof)
_ => panic!(), } else {
panic!()
} }
} }
} }

View File

@ -21,9 +21,13 @@
//! Contains the `listener_select_proto` code, which allows selecting a protocol thanks to //! Contains the `listener_select_proto` code, which allows selecting a protocol thanks to
//! `multistream-select` for the listener. //! `multistream-select` for the listener.
use bytes::Bytes;
use futures::{prelude::*, sink, stream::StreamFuture}; use futures::{prelude::*, sink, stream::StreamFuture};
use crate::protocol::{DialerToListenerMessage, Listener, ListenerFuture, ListenerToDialerMessage}; use crate::protocol::{
DialerToListenerMessage,
Listener,
ListenerFuture,
ListenerToDialerMessage
};
use log::{debug, trace}; use log::{debug, trace};
use std::mem; use std::mem;
use tokio_io::{AsyncRead, AsyncWrite}; use tokio_io::{AsyncRead, AsyncWrite};
@ -58,27 +62,31 @@ where
} }
/// Future, returned by `listener_select_proto` which selects a protocol among the ones supported. /// Future, returned by `listener_select_proto` which selects a protocol among the ones supported.
pub struct ListenerSelectFuture<R: AsyncRead + AsyncWrite, I, X> pub struct ListenerSelectFuture<R, I, X>
where where
for<'a> &'a I: IntoIterator<Item = X> R: AsyncRead + AsyncWrite,
for<'a> &'a I: IntoIterator<Item = X>,
X: AsRef<[u8]>
{ {
inner: ListenerSelectState<R, I, X> inner: ListenerSelectState<R, I, X>
} }
enum ListenerSelectState<R: AsyncRead + AsyncWrite, I, X> enum ListenerSelectState<R, I, X>
where where
for<'a> &'a I: IntoIterator<Item = X> R: AsyncRead + AsyncWrite,
for<'a> &'a I: IntoIterator<Item = X>,
X: AsRef<[u8]>
{ {
AwaitListener { AwaitListener {
listener_fut: ListenerFuture<R>, listener_fut: ListenerFuture<R, X>,
protocols: I protocols: I
}, },
Incoming { Incoming {
stream: StreamFuture<Listener<R>>, stream: StreamFuture<Listener<R, X>>,
protocols: I protocols: I
}, },
Outgoing { Outgoing {
sender: sink::Send<Listener<R>>, sender: sink::Send<Listener<R, X>>,
protocols: I, protocols: I,
outcome: Option<X> outcome: Option<X>
}, },
@ -87,9 +95,9 @@ where
impl<R, I, X> Future for ListenerSelectFuture<R, I, X> impl<R, I, X> Future for ListenerSelectFuture<R, I, X>
where where
for<'a> &'a I: IntoIterator<Item = X>,
R: AsyncRead + AsyncWrite, R: AsyncRead + AsyncWrite,
X: AsRef<[u8]> for<'a> &'a I: IntoIterator<Item = X>,
X: AsRef<[u8]> + Clone
{ {
type Item = (X, R, I); type Item = (X, R, I);
type Error = ProtocolChoiceError; type Error = ProtocolChoiceError;
@ -119,10 +127,12 @@ where
}; };
match msg { match msg {
Some(DialerToListenerMessage::ProtocolsListRequest) => { Some(DialerToListenerMessage::ProtocolsListRequest) => {
let msg = ListenerToDialerMessage::ProtocolsListResponse { trace!("protocols list response: {:?}", protocols
list: protocols.into_iter().map(|x| Bytes::from(x.as_ref())).collect(), .into_iter()
}; .map(|p| p.as_ref().into())
trace!("protocols list response: {:?}", msg); .collect::<Vec<Vec<u8>>>());
let list = protocols.into_iter().collect();
let msg = ListenerToDialerMessage::ProtocolsListResponse { list };
let sender = listener.send(msg); let sender = listener.send(msg);
self.inner = ListenerSelectState::Outgoing { self.inner = ListenerSelectState::Outgoing {
sender, sender,
@ -135,12 +145,14 @@ where
let mut send_back = ListenerToDialerMessage::NotAvailable; let mut send_back = ListenerToDialerMessage::NotAvailable;
for supported in &protocols { for supported in &protocols {
if name.as_ref() == supported.as_ref() { if name.as_ref() == supported.as_ref() {
send_back = ListenerToDialerMessage::ProtocolAck {name: name.clone()}; send_back = ListenerToDialerMessage::ProtocolAck {
name: supported.clone()
};
outcome = Some(supported); outcome = Some(supported);
break; break;
} }
} }
trace!("requested: {:?}, response: {:?}", name, send_back); trace!("requested: {:?}, supported: {}", name, outcome.is_some());
let sender = listener.send(send_back); let sender = listener.send(send_back);
self.inner = ListenerSelectState::Outgoing { sender, protocols, outcome } self.inner = ListenerSelectState::Outgoing { sender, protocols, outcome }
} }

View File

@ -20,34 +20,35 @@
//! Contains the `Dialer` wrapper, which allows raw communications with a listener. //! Contains the `Dialer` wrapper, which allows raw communications with a listener.
use bytes::Bytes; use bytes::{BufMut, Bytes, BytesMut};
use futures::{prelude::*, sink, Async, AsyncSink, StartSend, try_ready};
use crate::length_delimited::LengthDelimited; use crate::length_delimited::LengthDelimited;
use crate::protocol::DialerToListenerMessage; use crate::protocol::DialerToListenerMessage;
use crate::protocol::ListenerToDialerMessage; use crate::protocol::ListenerToDialerMessage;
use crate::protocol::MultistreamSelectError; use crate::protocol::MultistreamSelectError;
use crate::protocol::MULTISTREAM_PROTOCOL_WITH_LF; use crate::protocol::MULTISTREAM_PROTOCOL_WITH_LF;
use futures::{prelude::*, sink, Async, StartSend, try_ready};
use std::io;
use tokio_codec::Encoder;
use tokio_io::{AsyncRead, AsyncWrite}; use tokio_io::{AsyncRead, AsyncWrite};
use unsigned_varint::decode; use unsigned_varint::{decode, codec::Uvi};
/// Wraps around a `AsyncRead+AsyncWrite`.
/// Wraps around a `AsyncRead+AsyncWrite`. Assumes that we're on the dialer's side. Produces and /// Assumes that we're on the dialer's side. Produces and accepts messages.
/// accepts messages. pub struct Dialer<R, N> {
pub struct Dialer<R> { inner: LengthDelimited<R, MessageEncoder<N>>,
inner: LengthDelimited<Bytes, R>, handshake_finished: bool
handshake_finished: bool,
} }
impl<R> Dialer<R> impl<R, N> Dialer<R, N>
where where
R: AsyncRead + AsyncWrite, R: AsyncRead + AsyncWrite,
N: AsRef<[u8]>
{ {
/// Takes ownership of a socket and starts the handshake. If the handshake succeeds, the pub fn new(inner: R) -> DialerFuture<R, N> {
/// future returns a `Dialer`. let codec = MessageEncoder(std::marker::PhantomData);
pub fn new(inner: R) -> DialerFuture<R> { let sender = LengthDelimited::new(inner, codec);
let sender = LengthDelimited::new(inner);
DialerFuture { DialerFuture {
inner: sender.send(Bytes::from(MULTISTREAM_PROTOCOL_WITH_LF)) inner: sender.send(Message::Header)
} }
} }
@ -58,43 +59,20 @@ where
} }
} }
impl<R> Sink for Dialer<R> impl<R, N> Sink for Dialer<R, N>
where where
R: AsyncRead + AsyncWrite, R: AsyncRead + AsyncWrite,
N: AsRef<[u8]>
{ {
type SinkItem = DialerToListenerMessage; type SinkItem = DialerToListenerMessage<N>;
type SinkError = MultistreamSelectError; type SinkError = MultistreamSelectError;
#[inline]
fn start_send(&mut self, item: Self::SinkItem) -> StartSend<Self::SinkItem, Self::SinkError> { fn start_send(&mut self, item: Self::SinkItem) -> StartSend<Self::SinkItem, Self::SinkError> {
match item { match self.inner.start_send(Message::Body(item))? {
DialerToListenerMessage::ProtocolRequest { name } => { AsyncSink::NotReady(Message::Body(item)) => Ok(AsyncSink::NotReady(item)),
if !name.starts_with(b"/") { AsyncSink::NotReady(Message::Header) => unreachable!(),
return Err(MultistreamSelectError::WrongProtocolName); AsyncSink::Ready => Ok(AsyncSink::Ready)
}
let mut protocol = Bytes::from(name);
protocol.extend_from_slice(&[b'\n']);
match self.inner.start_send(protocol) {
Ok(AsyncSink::Ready) => Ok(AsyncSink::Ready),
Ok(AsyncSink::NotReady(mut protocol)) => {
let protocol_len = protocol.len();
protocol.truncate(protocol_len - 1);
Ok(AsyncSink::NotReady(
DialerToListenerMessage::ProtocolRequest { name: protocol },
))
}
Err(err) => Err(err.into()),
}
}
DialerToListenerMessage::ProtocolsListRequest => {
match self.inner.start_send(Bytes::from(&b"ls\n"[..])) {
Ok(AsyncSink::Ready) => Ok(AsyncSink::Ready),
Ok(AsyncSink::NotReady(_)) => Ok(AsyncSink::NotReady(
DialerToListenerMessage::ProtocolsListRequest,
)),
Err(err) => Err(err.into()),
}
}
} }
} }
@ -109,11 +87,11 @@ where
} }
} }
impl<R> Stream for Dialer<R> impl<R, N> Stream for Dialer<R, N>
where where
R: AsyncRead + AsyncWrite, R: AsyncRead + AsyncWrite
{ {
type Item = ListenerToDialerMessage; type Item = ListenerToDialerMessage<Bytes>;
type Error = MultistreamSelectError; type Error = MultistreamSelectError;
fn poll(&mut self) -> Poll<Option<Self::Item>, Self::Error> { fn poll(&mut self) -> Poll<Option<Self::Item>, Self::Error> {
@ -138,7 +116,7 @@ where
let frame_len = frame.len(); let frame_len = frame.len();
let protocol = frame.split_to(frame_len - 1); let protocol = frame.split_to(frame_len - 1);
return Ok(Async::Ready(Some(ListenerToDialerMessage::ProtocolAck { return Ok(Async::Ready(Some(ListenerToDialerMessage::ProtocolAck {
name: protocol, name: protocol
}))); })));
} else if frame == b"na\n"[..] { } else if frame == b"na\n"[..] {
return Ok(Async::Ready(Some(ListenerToDialerMessage::NotAvailable))); return Ok(Async::Ready(Some(ListenerToDialerMessage::NotAvailable)));
@ -166,12 +144,12 @@ where
} }
/// Future, returned by `Dialer::new`, which send the handshake and returns the actual `Dialer`. /// Future, returned by `Dialer::new`, which send the handshake and returns the actual `Dialer`.
pub struct DialerFuture<T: AsyncWrite> { pub struct DialerFuture<T: AsyncWrite, N: AsRef<[u8]>> {
inner: sink::Send<LengthDelimited<Bytes, T>> inner: sink::Send<LengthDelimited<T, MessageEncoder<N>>>
} }
impl<T: AsyncWrite> Future for DialerFuture<T> { impl<T: AsyncWrite, N: AsRef<[u8]>> Future for DialerFuture<T, N> {
type Item = Dialer<T>; type Item = Dialer<T, N>;
type Error = MultistreamSelectError; type Error = MultistreamSelectError;
fn poll(&mut self) -> Poll<Self::Item, Self::Error> { fn poll(&mut self) -> Poll<Self::Item, Self::Error> {
@ -180,15 +158,57 @@ impl<T: AsyncWrite> Future for DialerFuture<T> {
} }
} }
/// tokio-codec `Encoder` handling `DialerToListenerMessage` values.
struct MessageEncoder<N>(std::marker::PhantomData<N>);
enum Message<N> {
Header,
Body(DialerToListenerMessage<N>)
}
impl<N: AsRef<[u8]>> Encoder for MessageEncoder<N> {
type Item = Message<N>;
type Error = MultistreamSelectError;
fn encode(&mut self, item: Self::Item, dest: &mut BytesMut) -> Result<(), Self::Error> {
match item {
Message::Header => {
Uvi::<usize>::default().encode(MULTISTREAM_PROTOCOL_WITH_LF.len(), dest)?;
dest.reserve(MULTISTREAM_PROTOCOL_WITH_LF.len());
dest.put(MULTISTREAM_PROTOCOL_WITH_LF);
Ok(())
}
Message::Body(DialerToListenerMessage::ProtocolRequest { name }) => {
if !name.as_ref().starts_with(b"/") {
return Err(MultistreamSelectError::WrongProtocolName)
}
let len = name.as_ref().len() + 1; // + 1 for \n
if len > std::u16::MAX as usize {
return Err(io::Error::new(io::ErrorKind::InvalidData, "name too long").into())
}
Uvi::<usize>::default().encode(len, dest)?;
dest.reserve(len);
dest.put(name.as_ref());
dest.put(&b"\n"[..]);
Ok(())
}
Message::Body(DialerToListenerMessage::ProtocolsListRequest) => {
Uvi::<usize>::default().encode(3, dest)?;
dest.reserve(3);
dest.put(&b"ls\n"[..]);
Ok(())
}
}
}
}
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use crate::protocol::{Dialer, DialerToListenerMessage, MultistreamSelectError};
use tokio::runtime::current_thread::Runtime; use tokio::runtime::current_thread::Runtime;
use tokio_tcp::{TcpListener, TcpStream}; use tokio_tcp::{TcpListener, TcpStream};
use bytes::Bytes;
use futures::Future; use futures::Future;
use futures::{Sink, Stream}; use futures::{Sink, Stream};
use crate::protocol::{Dialer, DialerToListenerMessage, MultistreamSelectError};
#[test] #[test]
fn wrong_proto_name() { fn wrong_proto_name() {
@ -205,7 +225,7 @@ mod tests {
.from_err() .from_err()
.and_then(move |stream| Dialer::new(stream)) .and_then(move |stream| Dialer::new(stream))
.and_then(move |dialer| { .and_then(move |dialer| {
let p = Bytes::from("invalid_name"); let p = b"invalid_name";
dialer.send(DialerToListenerMessage::ProtocolRequest { name: p }) dialer.send(DialerToListenerMessage::ProtocolRequest { name: p })
}); });

View File

@ -20,33 +20,35 @@
//! Contains the `Listener` wrapper, which allows raw communications with a dialer. //! Contains the `Listener` wrapper, which allows raw communications with a dialer.
use bytes::Bytes; use bytes::{BufMut, Bytes, BytesMut};
use futures::{Async, AsyncSink, prelude::*, sink, stream::StreamFuture};
use crate::length_delimited::LengthDelimited; use crate::length_delimited::LengthDelimited;
use crate::protocol::DialerToListenerMessage; use crate::protocol::DialerToListenerMessage;
use crate::protocol::ListenerToDialerMessage; use crate::protocol::ListenerToDialerMessage;
use crate::protocol::MultistreamSelectError; use crate::protocol::MultistreamSelectError;
use crate::protocol::MULTISTREAM_PROTOCOL_WITH_LF; use crate::protocol::MULTISTREAM_PROTOCOL_WITH_LF;
use futures::{prelude::*, sink, stream::StreamFuture};
use log::{debug, trace}; use log::{debug, trace};
use std::mem; use std::{io, mem};
use tokio_codec::Encoder;
use tokio_io::{AsyncRead, AsyncWrite}; use tokio_io::{AsyncRead, AsyncWrite};
use unsigned_varint::encode; use unsigned_varint::{encode, codec::Uvi};
/// Wraps around a `AsyncRead+AsyncWrite`. Assumes that we're on the listener's side. Produces and /// Wraps around a `AsyncRead+AsyncWrite`. Assumes that we're on the listener's side. Produces and
/// accepts messages. /// accepts messages.
pub struct Listener<R> { pub struct Listener<R, N> {
inner: LengthDelimited<Bytes, R> inner: LengthDelimited<R, MessageEncoder<N>>
} }
impl<R> Listener<R> impl<R, N> Listener<R, N>
where where
R: AsyncRead + AsyncWrite, R: AsyncRead + AsyncWrite,
N: AsRef<[u8]>
{ {
/// Takes ownership of a socket and starts the handshake. If the handshake succeeds, the /// Takes ownership of a socket and starts the handshake. If the handshake succeeds, the
/// future returns a `Listener`. /// future returns a `Listener`.
pub fn new(inner: R) -> ListenerFuture<R> { pub fn new(inner: R) -> ListenerFuture<R, N> {
let inner = LengthDelimited::new(inner); let codec = MessageEncoder(std::marker::PhantomData);
let inner = LengthDelimited::new(inner, codec);
ListenerFuture { ListenerFuture {
inner: ListenerFutureState::Await { inner: inner.into_future() } inner: ListenerFutureState::Await { inner: inner.into_future() }
} }
@ -60,66 +62,20 @@ where
} }
} }
impl<R> Sink for Listener<R> impl<R, N> Sink for Listener<R, N>
where where
R: AsyncRead + AsyncWrite, R: AsyncRead + AsyncWrite,
N: AsRef<[u8]>
{ {
type SinkItem = ListenerToDialerMessage; type SinkItem = ListenerToDialerMessage<N>;
type SinkError = MultistreamSelectError; type SinkError = MultistreamSelectError;
#[inline] #[inline]
fn start_send(&mut self, item: Self::SinkItem) -> StartSend<Self::SinkItem, Self::SinkError> { fn start_send(&mut self, item: Self::SinkItem) -> StartSend<Self::SinkItem, Self::SinkError> {
match item { match self.inner.start_send(Message::Body(item))? {
ListenerToDialerMessage::ProtocolAck { name } => { AsyncSink::NotReady(Message::Body(item)) => Ok(AsyncSink::NotReady(item)),
if !name.starts_with(b"/") { AsyncSink::NotReady(Message::Header) => unreachable!(),
debug!("invalid protocol name {:?}", name); AsyncSink::Ready => Ok(AsyncSink::Ready)
return Err(MultistreamSelectError::WrongProtocolName);
}
let mut protocol = Bytes::from(name);
protocol.extend_from_slice(&[b'\n']);
match self.inner.start_send(protocol) {
Ok(AsyncSink::Ready) => Ok(AsyncSink::Ready),
Ok(AsyncSink::NotReady(mut protocol)) => {
let protocol_len = protocol.len();
protocol.truncate(protocol_len - 1);
Ok(AsyncSink::NotReady(ListenerToDialerMessage::ProtocolAck {
name: protocol,
}))
}
Err(err) => Err(err.into()),
}
}
ListenerToDialerMessage::NotAvailable => {
match self.inner.start_send(Bytes::from(&b"na\n"[..])) {
Ok(AsyncSink::Ready) => Ok(AsyncSink::Ready),
Ok(AsyncSink::NotReady(_)) => {
Ok(AsyncSink::NotReady(ListenerToDialerMessage::NotAvailable))
}
Err(err) => Err(err.into()),
}
}
ListenerToDialerMessage::ProtocolsListResponse { list } => {
use std::iter;
let mut buf = encode::usize_buffer();
let mut out_msg = Vec::from(encode::usize(list.len(), &mut buf));
for elem in &list {
out_msg.extend(encode::usize(elem.len() + 1, &mut buf)); // +1 for '\n'
out_msg.extend_from_slice(elem);
out_msg.extend(iter::once(b'\n'));
}
match self.inner.start_send(Bytes::from(out_msg)) {
Ok(AsyncSink::Ready) => Ok(AsyncSink::Ready),
Ok(AsyncSink::NotReady(_)) => {
let m = ListenerToDialerMessage::ProtocolsListResponse { list };
Ok(AsyncSink::NotReady(m))
}
Err(err) => Err(err.into()),
}
}
} }
} }
@ -134,11 +90,11 @@ where
} }
} }
impl<R> Stream for Listener<R> impl<R, N> Stream for Listener<R, N>
where where
R: AsyncRead + AsyncWrite, R: AsyncRead + AsyncWrite,
{ {
type Item = DialerToListenerMessage; type Item = DialerToListenerMessage<Bytes>;
type Error = MultistreamSelectError; type Error = MultistreamSelectError;
fn poll(&mut self) -> Poll<Option<Self::Item>, Self::Error> { fn poll(&mut self) -> Poll<Option<Self::Item>, Self::Error> {
@ -168,22 +124,22 @@ where
/// Future, returned by `Listener::new` which performs the handshake and returns /// Future, returned by `Listener::new` which performs the handshake and returns
/// the `Listener` if successful. /// the `Listener` if successful.
pub struct ListenerFuture<T: AsyncRead + AsyncWrite> { pub struct ListenerFuture<T: AsyncRead + AsyncWrite, N: AsRef<[u8]>> {
inner: ListenerFutureState<T> inner: ListenerFutureState<T, N>
} }
enum ListenerFutureState<T: AsyncRead + AsyncWrite> { enum ListenerFutureState<T: AsyncRead + AsyncWrite, N: AsRef<[u8]>> {
Await { Await {
inner: StreamFuture<LengthDelimited<Bytes, T>> inner: StreamFuture<LengthDelimited<T, MessageEncoder<N>>>
}, },
Reply { Reply {
sender: sink::Send<LengthDelimited<Bytes, T>> sender: sink::Send<LengthDelimited<T, MessageEncoder<N>>>
}, },
Undefined Undefined
} }
impl<T: AsyncRead + AsyncWrite> Future for ListenerFuture<T> { impl<T: AsyncRead + AsyncWrite, N: AsRef<[u8]>> Future for ListenerFuture<T, N> {
type Item = Listener<T>; type Item = Listener<T, N>;
type Error = MultistreamSelectError; type Error = MultistreamSelectError;
fn poll(&mut self) -> Poll<Self::Item, Self::Error> { fn poll(&mut self) -> Poll<Self::Item, Self::Error> {
@ -204,7 +160,7 @@ impl<T: AsyncRead + AsyncWrite> Future for ListenerFuture<T> {
return Err(MultistreamSelectError::FailedHandshake) return Err(MultistreamSelectError::FailedHandshake)
} }
trace!("sending back /multistream/<version> to finish the handshake"); trace!("sending back /multistream/<version> to finish the handshake");
let sender = socket.send(Bytes::from(MULTISTREAM_PROTOCOL_WITH_LF)); let sender = socket.send(Message::Header);
self.inner = ListenerFutureState::Reply { sender } self.inner = ListenerFutureState::Reply { sender }
} }
ListenerFutureState::Reply { mut sender } => { ListenerFutureState::Reply { mut sender } => {
@ -223,6 +179,66 @@ impl<T: AsyncRead + AsyncWrite> Future for ListenerFuture<T> {
} }
} }
/// tokio-codec `Encoder` handling `ListenerToDialerMessage` values.
struct MessageEncoder<N>(std::marker::PhantomData<N>);
enum Message<N> {
Header,
Body(ListenerToDialerMessage<N>)
}
impl<N: AsRef<[u8]>> Encoder for MessageEncoder<N> {
type Item = Message<N>;
type Error = MultistreamSelectError;
fn encode(&mut self, item: Self::Item, dest: &mut BytesMut) -> Result<(), Self::Error> {
match item {
Message::Header => {
Uvi::<usize>::default().encode(MULTISTREAM_PROTOCOL_WITH_LF.len(), dest)?;
dest.reserve(MULTISTREAM_PROTOCOL_WITH_LF.len());
dest.put(MULTISTREAM_PROTOCOL_WITH_LF);
Ok(())
}
Message::Body(ListenerToDialerMessage::ProtocolAck { name }) => {
if !name.as_ref().starts_with(b"/") {
return Err(MultistreamSelectError::WrongProtocolName)
}
let len = name.as_ref().len() + 1; // + 1 for \n
if len > std::u16::MAX as usize {
return Err(io::Error::new(io::ErrorKind::InvalidData, "name too long").into())
}
Uvi::<usize>::default().encode(len, dest)?;
dest.reserve(len);
dest.put(name.as_ref());
dest.put(&b"\n"[..]);
Ok(())
}
Message::Body(ListenerToDialerMessage::ProtocolsListResponse { list }) => {
let mut buf = encode::usize_buffer();
let mut out_msg = Vec::from(encode::usize(list.len(), &mut buf));
for e in &list {
if e.as_ref().len() + 1 > std::u16::MAX as usize {
return Err(io::Error::new(io::ErrorKind::InvalidData, "name too long").into())
}
out_msg.extend(encode::usize(e.as_ref().len() + 1, &mut buf)); // +1 for '\n'
out_msg.extend_from_slice(e.as_ref());
out_msg.push(b'\n')
}
let len = encode::usize(out_msg.len(), &mut buf);
dest.reserve(len.len() + out_msg.len());
dest.put(len);
dest.put(out_msg);
Ok(())
}
Message::Body(ListenerToDialerMessage::NotAvailable) => {
Uvi::<usize>::default().encode(3, dest)?;
dest.reserve(3);
dest.put(&b"na\n"[..]);
Ok(())
}
}
}
}
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
@ -250,7 +266,7 @@ mod tests {
let client = TcpStream::connect(&listener_addr) let client = TcpStream::connect(&listener_addr)
.from_err() .from_err()
.and_then(move |stream| Dialer::new(stream)); .and_then(move |stream| Dialer::<_, Bytes>::new(stream));
let mut rt = Runtime::new().unwrap(); let mut rt = Runtime::new().unwrap();
match rt.block_on(server.join(client)) { match rt.block_on(server.join(client)) {

View File

@ -20,8 +20,6 @@
//! Contains lower-level structs to handle the multistream protocol. //! Contains lower-level structs to handle the multistream protocol.
use bytes::Bytes;
mod dialer; mod dialer;
mod error; mod error;
mod listener; mod listener;
@ -34,14 +32,14 @@ pub use self::listener::{Listener, ListenerFuture};
/// Message sent from the dialer to the listener. /// Message sent from the dialer to the listener.
#[derive(Debug, Clone, PartialEq, Eq)] #[derive(Debug, Clone, PartialEq, Eq)]
pub enum DialerToListenerMessage { pub enum DialerToListenerMessage<N> {
/// The dialer wants us to use a protocol. /// The dialer wants us to use a protocol.
/// ///
/// If this is accepted (by receiving back a `ProtocolAck`), then we immediately start /// If this is accepted (by receiving back a `ProtocolAck`), then we immediately start
/// communicating in the new protocol. /// communicating in the new protocol.
ProtocolRequest { ProtocolRequest {
/// Name of the protocol. /// Name of the protocol.
name: Bytes, name: N
}, },
/// The dialer requested the list of protocols that the listener supports. /// The dialer requested the list of protocols that the listener supports.
@ -50,10 +48,10 @@ pub enum DialerToListenerMessage {
/// Message sent from the listener to the dialer. /// Message sent from the listener to the dialer.
#[derive(Debug, Clone, PartialEq, Eq)] #[derive(Debug, Clone, PartialEq, Eq)]
pub enum ListenerToDialerMessage { pub enum ListenerToDialerMessage<N> {
/// The protocol requested by the dialer is accepted. The socket immediately starts using the /// The protocol requested by the dialer is accepted. The socket immediately starts using the
/// new protocol. /// new protocol.
ProtocolAck { name: Bytes }, ProtocolAck { name: N },
/// The protocol requested by the dialer is not supported or available. /// The protocol requested by the dialer is not supported or available.
NotAvailable, NotAvailable,
@ -62,6 +60,7 @@ pub enum ListenerToDialerMessage {
ProtocolsListResponse { ProtocolsListResponse {
/// The list of protocols. /// The list of protocols.
// TODO: use some sort of iterator // TODO: use some sort of iterator
list: Vec<Bytes>, list: Vec<N>,
}, },
} }

View File

@ -22,15 +22,13 @@
#![cfg(test)] #![cfg(test)]
use crate::ProtocolChoiceError;
use crate::dialer_select::{dialer_select_proto_parallel, dialer_select_proto_serial};
use crate::protocol::{Dialer, DialerToListenerMessage, Listener, ListenerToDialerMessage};
use crate::{dialer_select_proto, listener_select_proto};
use futures::prelude::*;
use tokio::runtime::current_thread::Runtime; use tokio::runtime::current_thread::Runtime;
use tokio_tcp::{TcpListener, TcpStream}; use tokio_tcp::{TcpListener, TcpStream};
use bytes::Bytes;
use crate::dialer_select::{dialer_select_proto_parallel, dialer_select_proto_serial};
use futures::Future;
use futures::{Sink, Stream};
use crate::protocol::{Dialer, DialerToListenerMessage, Listener, ListenerToDialerMessage};
use crate::ProtocolChoiceError;
use crate::{dialer_select_proto, listener_select_proto};
/// Holds a `Vec` and satifies the iterator requirements of `listener_select_proto`. /// Holds a `Vec` and satifies the iterator requirements of `listener_select_proto`.
struct VecRefIntoIter<T>(Vec<T>); struct VecRefIntoIter<T>(Vec<T>);
@ -68,7 +66,7 @@ fn negotiate_with_self_succeeds() {
.from_err() .from_err()
.and_then(move |stream| Dialer::new(stream)) .and_then(move |stream| Dialer::new(stream))
.and_then(move |dialer| { .and_then(move |dialer| {
let p = Bytes::from("/hello/1.0.0"); let p = b"/hello/1.0.0";
dialer.send(DialerToListenerMessage::ProtocolRequest { name: p }) dialer.send(DialerToListenerMessage::ProtocolRequest { name: p })
}) })
.and_then(move |dialer| dialer.into_future().map_err(|(e, _)| e)) .and_then(move |dialer| dialer.into_future().map_err(|(e, _)| e))