multistream-select: Less allocations. (#800)

This commit is contained in:
Toralf Wittner
2019-01-09 15:09:35 +01:00
committed by GitHub
parent aedf9c0c31
commit f1959252b7
9 changed files with 467 additions and 372 deletions

View File

@ -224,6 +224,7 @@ type NameWrapIter<I> =
std::iter::Map<I, fn(<I as Iterator>::Item) -> NameWrap<<I as Iterator>::Item>>;
/// Wrapper type to expose an `AsRef<[u8]>` impl for all types implementing `ProtocolName`.
#[derive(Clone)]
struct NameWrap<N>(N);
impl<N: ProtocolName> AsRef<[u8]> for NameWrap<N> {

View File

@ -91,7 +91,7 @@ impl<T: AsRef<[u8]>> ProtocolName for T {
/// or both.
pub trait UpgradeInfo {
/// Opaque type representing a negotiable protocol.
type Info: ProtocolName;
type Info: ProtocolName + Clone;
/// Iterator returned by `protocol_info`.
type InfoIter: IntoIterator<Item = Self::Info>;

View File

@ -21,9 +21,13 @@
//! Contains the `dialer_select_proto` code, which allows selecting a protocol thanks to
//! `multistream-select` for the dialer.
use bytes::Bytes;
use futures::{future::Either, prelude::*, sink, stream::StreamFuture};
use crate::protocol::{Dialer, DialerFuture, DialerToListenerMessage, ListenerToDialerMessage};
use futures::{future::Either, prelude::*, stream::StreamFuture};
use crate::protocol::{
Dialer,
DialerFuture,
DialerToListenerMessage,
ListenerToDialerMessage
};
use log::trace;
use std::mem;
use tokio_io::{AsyncRead, AsyncWrite};
@ -44,7 +48,6 @@ pub type DialerSelectFuture<R, I> = Either<DialerSelectSeq<R, I>, DialerSelectPa
/// remote, and the protocol name that we passed (so that you don't have to clone the name). On
/// success, the function returns the identifier (of type `P`), plus the socket which now uses that
/// chosen protocol.
#[inline]
pub fn dialer_select_proto<R, I>(inner: R, protocols: I) -> DialerSelectFuture<R, I::IntoIter>
where
R: AsyncRead + AsyncWrite,
@ -64,12 +67,13 @@ where
///
/// Same as `dialer_select_proto`. Tries protocols one by one. The iterator doesn't need to produce
/// match functions, because it's not needed.
pub fn dialer_select_proto_serial<R, I>(inner: R, protocols: I,) -> DialerSelectSeq<R, I>
pub fn dialer_select_proto_serial<R, I>(inner: R, protocols: I) -> DialerSelectSeq<R, I::IntoIter>
where
R: AsyncRead + AsyncWrite,
I: Iterator,
I: IntoIterator,
I::Item: AsRef<[u8]>
{
let protocols = protocols.into_iter();
DialerSelectSeq {
inner: DialerSelectSeqState::AwaitDialer { dialer_fut: Dialer::new(inner), protocols }
}
@ -78,26 +82,37 @@ where
/// Future, returned by `dialer_select_proto_serial` which selects a protocol
/// and dialer sequentially.
pub struct DialerSelectSeq<R: AsyncRead + AsyncWrite, I: Iterator> {
pub struct DialerSelectSeq<R, I>
where
R: AsyncRead + AsyncWrite,
I: Iterator,
I::Item: AsRef<[u8]>
{
inner: DialerSelectSeqState<R, I>
}
enum DialerSelectSeqState<R: AsyncRead + AsyncWrite, I: Iterator> {
enum DialerSelectSeqState<R, I>
where
R: AsyncRead + AsyncWrite,
I: Iterator,
I::Item: AsRef<[u8]>
{
AwaitDialer {
dialer_fut: DialerFuture<R>,
dialer_fut: DialerFuture<R, I::Item>,
protocols: I
},
NextProtocol {
dialer: Dialer<R>,
dialer: Dialer<R, I::Item>,
proto_name: I::Item,
protocols: I
},
SendProtocol {
sender: sink::Send<Dialer<R>>,
FlushProtocol {
dialer: Dialer<R, I::Item>,
proto_name: I::Item,
protocols: I
},
AwaitProtocol {
stream: StreamFuture<Dialer<R>>,
stream: StreamFuture<Dialer<R, I::Item>>,
proto_name: I::Item,
protocols: I
},
@ -106,9 +121,9 @@ enum DialerSelectSeqState<R: AsyncRead + AsyncWrite, I: Iterator> {
impl<R, I> Future for DialerSelectSeq<R, I>
where
I: Iterator,
I::Item: AsRef<[u8]>,
R: AsyncRead + AsyncWrite,
I: Iterator,
I::Item: AsRef<[u8]> + Clone
{
type Item = (I::Item, R);
type Error = ProtocolChoiceError;
@ -116,7 +131,7 @@ where
fn poll(&mut self) -> Poll<Self::Item, Self::Error> {
loop {
match mem::replace(&mut self.inner, DialerSelectSeqState::Undefined) {
DialerSelectSeqState::AwaitDialer { mut dialer_fut, protocols } => {
DialerSelectSeqState::AwaitDialer { mut dialer_fut, mut protocols } => {
let dialer = match dialer_fut.poll()? {
Async::Ready(d) => d,
Async::NotReady => {
@ -124,42 +139,57 @@ where
return Ok(Async::NotReady)
}
};
self.inner = DialerSelectSeqState::NextProtocol { dialer, protocols }
let proto_name = protocols.next().ok_or(ProtocolChoiceError::NoProtocolFound)?;
self.inner = DialerSelectSeqState::NextProtocol {
dialer,
protocols,
proto_name
}
DialerSelectSeqState::NextProtocol { dialer, mut protocols } => {
let proto_name =
protocols.next().ok_or(ProtocolChoiceError::NoProtocolFound)?;
}
DialerSelectSeqState::NextProtocol { mut dialer, protocols, proto_name } => {
trace!("sending {:?}", proto_name.as_ref());
let req = DialerToListenerMessage::ProtocolRequest {
name: Bytes::from(proto_name.as_ref())
name: proto_name.clone()
};
trace!("sending {:?}", req);
let sender = dialer.send(req);
self.inner = DialerSelectSeqState::SendProtocol {
sender,
match dialer.start_send(req)? {
AsyncSink::Ready => {
self.inner = DialerSelectSeqState::FlushProtocol {
dialer,
proto_name,
protocols
}
}
DialerSelectSeqState::SendProtocol { mut sender, proto_name, protocols } => {
let dialer = match sender.poll()? {
Async::Ready(d) => d,
Async::NotReady => {
self.inner = DialerSelectSeqState::SendProtocol {
sender,
proto_name,
protocols
AsyncSink::NotReady(_) => {
self.inner = DialerSelectSeqState::NextProtocol {
dialer,
protocols,
proto_name
};
return Ok(Async::NotReady)
}
};
}
}
DialerSelectSeqState::FlushProtocol { mut dialer, proto_name, protocols } => {
match dialer.poll_complete()? {
Async::Ready(()) => {
let stream = dialer.into_future();
self.inner = DialerSelectSeqState::AwaitProtocol {
stream,
proto_name,
protocols
};
}
DialerSelectSeqState::AwaitProtocol { mut stream, proto_name, protocols } => {
}
Async::NotReady => {
self.inner = DialerSelectSeqState::FlushProtocol {
dialer,
proto_name,
protocols
};
return Ok(Async::NotReady)
}
}
}
DialerSelectSeqState::AwaitProtocol { mut stream, proto_name, mut protocols } => {
let (m, r) = match stream.poll() {
Ok(Async::Ready(x)) => x,
Ok(Async::NotReady) => {
@ -178,9 +208,15 @@ where
if name.as_ref() == proto_name.as_ref() =>
{
return Ok(Async::Ready((proto_name, r.into_inner())))
},
}
ListenerToDialerMessage::NotAvailable => {
self.inner = DialerSelectSeqState::NextProtocol { dialer: r, protocols }
let proto_name = protocols.next()
.ok_or(ProtocolChoiceError::NoProtocolFound)?;
self.inner = DialerSelectSeqState::NextProtocol {
dialer: r,
protocols,
proto_name
}
}
_ => return Err(ProtocolChoiceError::UnexpectedMessage)
}
@ -192,17 +228,17 @@ where
}
}
/// Helps selecting a protocol amongst the ones supported.
///
/// Same as `dialer_select_proto`. Queries the list of supported protocols from the remote, then
/// chooses the most appropriate one.
pub fn dialer_select_proto_parallel<R, I>(inner: R, protocols: I) -> DialerSelectPar<R, I>
pub fn dialer_select_proto_parallel<R, I>(inner: R, protocols: I) -> DialerSelectPar<R, I::IntoIter>
where
I: Iterator,
I::Item: AsRef<[u8]>,
R: AsyncRead + AsyncWrite
R: AsyncRead + AsyncWrite,
I: IntoIterator,
I::Item: AsRef<[u8]>
{
let protocols = protocols.into_iter();
DialerSelectPar {
inner: DialerSelectParState::AwaitDialer { dialer_fut: Dialer::new(inner), protocols }
}
@ -212,29 +248,47 @@ where
/// Future, returned by `dialer_select_proto_parallel`, which selects a protocol and dialer in
/// parellel, by first requesting the liste of protocols supported by the remote endpoint and
/// then selecting the most appropriate one by applying a match predicate to the result.
pub struct DialerSelectPar<R: AsyncRead + AsyncWrite, I: Iterator> {
pub struct DialerSelectPar<R, I>
where
R: AsyncRead + AsyncWrite,
I: Iterator,
I::Item: AsRef<[u8]>
{
inner: DialerSelectParState<R, I>
}
enum DialerSelectParState<R: AsyncRead + AsyncWrite, I: Iterator> {
enum DialerSelectParState<R, I>
where
R: AsyncRead + AsyncWrite,
I: Iterator,
I::Item: AsRef<[u8]>
{
AwaitDialer {
dialer_fut: DialerFuture<R>,
dialer_fut: DialerFuture<R, I::Item>,
protocols: I
},
SendRequest {
sender: sink::Send<Dialer<R>>,
ProtocolList {
dialer: Dialer<R, I::Item>,
protocols: I
},
AwaitResponse {
stream: StreamFuture<Dialer<R>>,
FlushListRequest {
dialer: Dialer<R, I::Item>,
protocols: I
},
SendProtocol {
sender: sink::Send<Dialer<R>>,
AwaitListResponse {
stream: StreamFuture<Dialer<R, I::Item>>,
protocols: I,
},
Protocol {
dialer: Dialer<R, I::Item>,
proto_name: I::Item
},
FlushProtocol {
dialer: Dialer<R, I::Item>,
proto_name: I::Item
},
AwaitProtocol {
stream: StreamFuture<Dialer<R>>,
stream: StreamFuture<Dialer<R, I::Item>>,
proto_name: I::Item
},
Undefined
@ -242,9 +296,9 @@ enum DialerSelectParState<R: AsyncRead + AsyncWrite, I: Iterator> {
impl<R, I> Future for DialerSelectPar<R, I>
where
I: Iterator,
I::Item: AsRef<[u8]>,
R: AsyncRead + AsyncWrite,
I: Iterator,
I::Item: AsRef<[u8]> + Clone
{
type Item = (I::Item, R);
type Error = ProtocolChoiceError;
@ -253,41 +307,63 @@ where
loop {
match mem::replace(&mut self.inner, DialerSelectParState::Undefined) {
DialerSelectParState::AwaitDialer { mut dialer_fut, protocols } => {
let dialer = match dialer_fut.poll()? {
Async::Ready(d) => d,
match dialer_fut.poll()? {
Async::Ready(dialer) => {
self.inner = DialerSelectParState::ProtocolList { dialer, protocols }
}
Async::NotReady => {
self.inner = DialerSelectParState::AwaitDialer { dialer_fut, protocols };
return Ok(Async::NotReady)
}
};
trace!("requesting protocols list");
let sender = dialer.send(DialerToListenerMessage::ProtocolsListRequest);
self.inner = DialerSelectParState::SendRequest { sender, protocols };
}
DialerSelectParState::SendRequest { mut sender, protocols } => {
let dialer = match sender.poll()? {
Async::Ready(d) => d,
Async::NotReady => {
self.inner = DialerSelectParState::SendRequest { sender, protocols };
}
DialerSelectParState::ProtocolList { mut dialer, protocols } => {
trace!("requesting protocols list");
match dialer.start_send(DialerToListenerMessage::ProtocolsListRequest)? {
AsyncSink::Ready => {
self.inner = DialerSelectParState::FlushListRequest {
dialer,
protocols
}
}
AsyncSink::NotReady(_) => {
self.inner = DialerSelectParState::ProtocolList { dialer, protocols };
return Ok(Async::NotReady)
}
};
let stream = dialer.into_future();
self.inner = DialerSelectParState::AwaitResponse { stream, protocols };
}
DialerSelectParState::AwaitResponse { mut stream, protocols } => {
let (m, d) = match stream.poll() {
}
DialerSelectParState::FlushListRequest { mut dialer, protocols } => {
match dialer.poll_complete()? {
Async::Ready(()) => {
self.inner = DialerSelectParState::AwaitListResponse {
stream: dialer.into_future(),
protocols
}
}
Async::NotReady => {
self.inner = DialerSelectParState::FlushListRequest {
dialer,
protocols
};
return Ok(Async::NotReady)
}
}
}
DialerSelectParState::AwaitListResponse { mut stream, protocols } => {
let (resp, dialer) = match stream.poll() {
Ok(Async::Ready(x)) => x,
Ok(Async::NotReady) => {
self.inner = DialerSelectParState::AwaitResponse { stream, protocols };
self.inner = DialerSelectParState::AwaitListResponse { stream, protocols };
return Ok(Async::NotReady)
}
Err((e, _)) => return Err(ProtocolChoiceError::from(e))
};
trace!("protocols list response: {:?}", m);
let list = match m {
Some(ListenerToDialerMessage::ProtocolsListResponse { list }) => list,
_ => return Err(ProtocolChoiceError::UnexpectedMessage),
trace!("protocols list response: {:?}", resp);
let list =
if let Some(ListenerToDialerMessage::ProtocolsListResponse { list }) = resp {
list
} else {
return Err(ProtocolChoiceError::UnexpectedMessage)
};
let mut found = None;
for local_name in protocols {
@ -302,47 +378,52 @@ where
}
}
let proto_name = found.ok_or(ProtocolChoiceError::NoProtocolFound)?;
trace!("sending {:?}", proto_name.as_ref());
let sender = d.send(DialerToListenerMessage::ProtocolRequest {
name: Bytes::from(proto_name.as_ref())
});
self.inner = DialerSelectParState::SendProtocol { sender, proto_name };
self.inner = DialerSelectParState::Protocol { dialer, proto_name }
}
DialerSelectParState::SendProtocol { mut sender, proto_name } => {
let dialer = match sender.poll()? {
Async::Ready(d) => d,
Async::NotReady => {
self.inner = DialerSelectParState::SendProtocol {
sender,
proto_name
DialerSelectParState::Protocol { mut dialer, proto_name } => {
trace!("requesting protocol: {:?}", proto_name.as_ref());
let req = DialerToListenerMessage::ProtocolRequest {
name: proto_name.clone()
};
match dialer.start_send(req)? {
AsyncSink::Ready => {
self.inner = DialerSelectParState::FlushProtocol { dialer, proto_name }
}
AsyncSink::NotReady(_) => {
self.inner = DialerSelectParState::Protocol { dialer, proto_name };
return Ok(Async::NotReady)
}
};
let stream = dialer.into_future();
}
}
DialerSelectParState::FlushProtocol { mut dialer, proto_name } => {
match dialer.poll_complete()? {
Async::Ready(()) => {
self.inner = DialerSelectParState::AwaitProtocol {
stream,
stream: dialer.into_future(),
proto_name
};
}
}
Async::NotReady => {
self.inner = DialerSelectParState::FlushProtocol { dialer, proto_name };
return Ok(Async::NotReady)
}
}
}
DialerSelectParState::AwaitProtocol { mut stream, proto_name } => {
let (m, r) = match stream.poll() {
let (resp, dialer) = match stream.poll() {
Ok(Async::Ready(x)) => x,
Ok(Async::NotReady) => {
self.inner = DialerSelectParState::AwaitProtocol {
stream,
proto_name
};
self.inner = DialerSelectParState::AwaitProtocol { stream, proto_name };
return Ok(Async::NotReady)
}
Err((e, _)) => return Err(ProtocolChoiceError::from(e))
};
trace!("received {:?}", m);
match m {
trace!("received {:?}", resp);
match resp {
Some(ListenerToDialerMessage::ProtocolAck { ref name })
if name.as_ref() == proto_name.as_ref() =>
{
return Ok(Async::Ready((proto_name, r.into_inner())))
return Ok(Async::Ready((proto_name, dialer.into_inner())))
}
_ => return Err(ProtocolChoiceError::UnexpectedMessage)
}
@ -354,4 +435,3 @@ where
}
}

View File

@ -18,21 +18,22 @@
// FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
// DEALINGS IN THE SOFTWARE.
use bytes::Bytes;
use futures::{Async, Poll, Sink, StartSend, Stream};
use smallvec::SmallVec;
use std::{io::{Error as IoError, ErrorKind as IoErrorKind}, marker::PhantomData, u16};
use tokio_codec::FramedWrite;
use std::{io, u16};
use tokio_codec::{Encoder, FramedWrite};
use tokio_io::{AsyncRead, AsyncWrite};
use unsigned_varint::codec::UviBytes;
use unsigned_varint::decode;
/// `Stream` and `Sink` wrapping some `AsyncRead + AsyncWrite` object to read
/// and write unsigned-varint prefixed frames.
///
/// We purposely only support a frame length of under 64kiB. Frames mostly consist
/// in a short protocol name, which is highly unlikely to be more than 64kiB long.
pub struct LengthDelimited<I, S> {
pub struct LengthDelimited<R, C> {
// The inner socket where data is pulled from.
inner: FramedWrite<S, UviBytes>,
inner: FramedWrite<R, C>,
// Intermediary buffer where we put either the length of the next frame of data, or the frame
// of data itself before it is returned.
// Must always contain enough space to read data from `inner`.
@ -40,8 +41,7 @@ pub struct LengthDelimited<I, S> {
// Number of bytes within `internal_buffer` that contain valid data.
internal_buffer_pos: usize,
// State of the decoder.
state: State,
marker: PhantomData<I>,
state: State
}
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
@ -52,24 +52,21 @@ enum State {
ReadingData { frame_len: u16 },
}
impl<I, S> LengthDelimited<I, S>
impl<R, C> LengthDelimited<R, C>
where
S: AsyncWrite
R: AsyncWrite,
C: Encoder
{
pub fn new(inner: S) -> LengthDelimited<I, S> {
let mut encoder = UviBytes::default();
encoder.set_max_len(usize::from(u16::MAX));
pub fn new(inner: R, codec: C) -> LengthDelimited<R, C> {
LengthDelimited {
inner: FramedWrite::new(inner, encoder),
inner: FramedWrite::new(inner, codec),
internal_buffer: {
let mut v = SmallVec::new();
v.push(0);
v
},
internal_buffer_pos: 0,
state: State::ReadingLength,
marker: PhantomData,
state: State::ReadingLength
}
}
@ -85,20 +82,19 @@ where
/// the modifiers provided by the `futures` crate) will always leave the object in a state in
/// which `into_inner()` will not panic.
#[inline]
pub fn into_inner(self) -> S {
pub fn into_inner(self) -> R {
assert_eq!(self.state, State::ReadingLength);
assert_eq!(self.internal_buffer_pos, 0);
self.inner.into_inner()
}
}
impl<I, S> Stream for LengthDelimited<I, S>
impl<R, C> Stream for LengthDelimited<R, C>
where
S: AsyncRead,
I: for<'r> From<&'r [u8]>,
R: AsyncRead
{
type Item = I;
type Error = IoError;
type Item = Bytes;
type Error = io::Error;
fn poll(&mut self) -> Poll<Option<Self::Item>, Self::Error> {
loop {
@ -107,23 +103,21 @@ where
match self.state {
State::ReadingLength => {
match self.inner
.get_mut()
.read(&mut self.internal_buffer[self.internal_buffer_pos..])
{
let slice = &mut self.internal_buffer[self.internal_buffer_pos..];
match self.inner.get_mut().read(slice) {
Ok(0) => {
// EOF
if self.internal_buffer_pos == 0 {
return Ok(Async::Ready(None));
} else {
return Err(IoError::new(IoErrorKind::BrokenPipe, "unexpected eof"));
return Err(io::ErrorKind::UnexpectedEof.into());
}
}
Ok(n) => {
debug_assert_eq!(n, 1);
self.internal_buffer_pos += n;
}
Err(ref err) if err.kind() == IoErrorKind::WouldBlock => {
Err(ref err) if err.kind() == io::ErrorKind::WouldBlock => {
return Ok(Async::NotReady);
}
Err(err) => {
@ -136,7 +130,10 @@ where
if (*self.internal_buffer.last().unwrap_or(&0) & 0x80) == 0 {
// End of length prefix. Most of the time we will switch to reading data,
// but we need to handle a few corner cases first.
let frame_len = decode_length_prefix(&self.internal_buffer);
let (frame_len, _) = decode::u16(&self.internal_buffer).map_err(|e| {
log::debug!("invalid length prefix: {}", e);
io::Error::new(io::ErrorKind::InvalidData, "invalid length prefix")
})?;
if frame_len >= 1 {
self.state = State::ReadingData { frame_len };
@ -154,33 +151,22 @@ where
}
} else if self.internal_buffer_pos >= 2 {
// Length prefix is too long. See module doc for info about max frame len.
return Err(IoError::new(
IoErrorKind::InvalidData,
"frame length too long",
));
return Err(io::Error::new(io::ErrorKind::InvalidData, "frame length too long"));
} else {
// Prepare for next read.
self.internal_buffer.push(0);
}
}
State::ReadingData { frame_len } => {
match self.inner
.get_mut()
.read(&mut self.internal_buffer[self.internal_buffer_pos..])
{
Ok(0) => {
return Err(IoError::new(IoErrorKind::BrokenPipe, "unexpected eof"));
}
let slice = &mut self.internal_buffer[self.internal_buffer_pos..];
match self.inner.get_mut().read(slice) {
Ok(0) => return Err(io::ErrorKind::UnexpectedEof.into()),
Ok(n) => self.internal_buffer_pos += n,
Err(ref err) if err.kind() == IoErrorKind::WouldBlock => {
return Ok(Async::NotReady);
}
Err(err) => {
return Err(err);
Err(ref err) if err.kind() == io::ErrorKind::WouldBlock => {
return Ok(Async::NotReady)
}
Err(err) => return Err(err)
};
if self.internal_buffer_pos >= frame_len as usize {
// Finished reading the frame of data.
self.state = State::ReadingLength;
@ -196,12 +182,13 @@ where
}
}
impl<I, S> Sink for LengthDelimited<I, S>
impl<R, C> Sink for LengthDelimited<R, C>
where
S: AsyncWrite
R: AsyncWrite,
C: Encoder
{
type SinkItem = <FramedWrite<S, UviBytes> as Sink>::SinkItem;
type SinkError = <FramedWrite<S, UviBytes> as Sink>::SinkError;
type SinkItem = <FramedWrite<R, C> as Sink>::SinkItem;
type SinkError = <FramedWrite<R, C> as Sink>::SinkError;
#[inline]
fn start_send(&mut self, item: Self::SinkItem) -> StartSend<Self::SinkItem, Self::SinkError> {
@ -219,33 +206,17 @@ where
}
}
fn decode_length_prefix(buf: &[u8]) -> u16 {
debug_assert!(buf.len() <= 2);
let mut sum = 0u16;
for &byte in buf.iter().rev() {
let byte = byte & 0x7f;
sum <<= 7;
debug_assert!(sum.checked_add(u16::from(byte)).is_some());
sum += u16::from(byte);
}
sum
}
#[cfg(test)]
mod tests {
use futures::{Future, Stream};
use crate::length_delimited::LengthDelimited;
use std::io::Cursor;
use std::io::ErrorKind;
use std::io::{Cursor, ErrorKind};
use unsigned_varint::codec::UviBytes;
#[test]
fn basic_read() {
let data = vec![6, 9, 8, 7, 6, 5, 4];
let framed = LengthDelimited::<Vec<u8>, _>::new(Cursor::new(data));
let framed = LengthDelimited::new(Cursor::new(data), UviBytes::<Vec<_>>::default());
let recved = framed.collect().wait().unwrap();
assert_eq!(recved, vec![vec![9, 8, 7, 6, 5, 4]]);
}
@ -253,8 +224,7 @@ mod tests {
#[test]
fn basic_read_two() {
let data = vec![6, 9, 8, 7, 6, 5, 4, 3, 9, 8, 7];
let framed = LengthDelimited::<Vec<u8>, _>::new(Cursor::new(data));
let framed = LengthDelimited::new(Cursor::new(data), UviBytes::<Vec<_>>::default());
let recved = framed.collect().wait().unwrap();
assert_eq!(recved, vec![vec![9, 8, 7, 6, 5, 4], vec![9, 8, 7]]);
}
@ -266,8 +236,7 @@ mod tests {
let frame = (0..len).map(|n| (n & 0xff) as u8).collect::<Vec<_>>();
let mut data = vec![(len & 0x7f) as u8 | 0x80, (len >> 7) as u8];
data.extend(frame.clone().into_iter());
let framed = LengthDelimited::<Vec<u8>, _>::new(Cursor::new(data));
let framed = LengthDelimited::new(Cursor::new(data), UviBytes::<Vec<_>>::default());
let recved = framed
.into_future()
.map(|(m, _)| m)
@ -281,24 +250,24 @@ mod tests {
fn packet_len_too_long() {
let mut data = vec![0x81, 0x81, 0x1];
data.extend((0..16513).map(|_| 0));
let framed = LengthDelimited::<Vec<u8>, _>::new(Cursor::new(data));
let framed = LengthDelimited::new(Cursor::new(data), UviBytes::<Vec<_>>::default());
let recved = framed
.into_future()
.map(|(m, _)| m)
.map_err(|(err, _)| err)
.wait();
match recved {
Err(io_err) => assert_eq!(io_err.kind(), ErrorKind::InvalidData),
_ => panic!(),
if let Err(io_err) = recved {
assert_eq!(io_err.kind(), ErrorKind::InvalidData)
} else {
panic!()
}
}
#[test]
fn empty_frames() {
let data = vec![0, 0, 6, 9, 8, 7, 6, 5, 4, 0, 3, 9, 8, 7];
let framed = LengthDelimited::<Vec<u8>, _>::new(Cursor::new(data));
let framed = LengthDelimited::new(Cursor::new(data), UviBytes::<Vec<_>>::default());
let recved = framed.collect().wait().unwrap();
assert_eq!(
recved,
@ -315,36 +284,36 @@ mod tests {
#[test]
fn unexpected_eof_in_len() {
let data = vec![0x89];
let framed = LengthDelimited::<Vec<u8>, _>::new(Cursor::new(data));
let framed = LengthDelimited::new(Cursor::new(data), UviBytes::<Vec<_>>::default());
let recved = framed.collect().wait();
match recved {
Err(io_err) => assert_eq!(io_err.kind(), ErrorKind::BrokenPipe),
_ => panic!(),
if let Err(io_err) = recved {
assert_eq!(io_err.kind(), ErrorKind::UnexpectedEof)
} else {
panic!()
}
}
#[test]
fn unexpected_eof_in_data() {
let data = vec![5];
let framed = LengthDelimited::<Vec<u8>, _>::new(Cursor::new(data));
let framed = LengthDelimited::new(Cursor::new(data), UviBytes::<Vec<_>>::default());
let recved = framed.collect().wait();
match recved {
Err(io_err) => assert_eq!(io_err.kind(), ErrorKind::BrokenPipe),
_ => panic!(),
if let Err(io_err) = recved {
assert_eq!(io_err.kind(), ErrorKind::UnexpectedEof)
} else {
panic!()
}
}
#[test]
fn unexpected_eof_in_data2() {
let data = vec![5, 9, 8, 7];
let framed = LengthDelimited::<Vec<u8>, _>::new(Cursor::new(data));
let framed = LengthDelimited::new(Cursor::new(data), UviBytes::<Vec<_>>::default());
let recved = framed.collect().wait();
match recved {
Err(io_err) => assert_eq!(io_err.kind(), ErrorKind::BrokenPipe),
_ => panic!(),
if let Err(io_err) = recved {
assert_eq!(io_err.kind(), ErrorKind::UnexpectedEof)
} else {
panic!()
}
}
}

View File

@ -21,9 +21,13 @@
//! Contains the `listener_select_proto` code, which allows selecting a protocol thanks to
//! `multistream-select` for the listener.
use bytes::Bytes;
use futures::{prelude::*, sink, stream::StreamFuture};
use crate::protocol::{DialerToListenerMessage, Listener, ListenerFuture, ListenerToDialerMessage};
use crate::protocol::{
DialerToListenerMessage,
Listener,
ListenerFuture,
ListenerToDialerMessage
};
use log::{debug, trace};
use std::mem;
use tokio_io::{AsyncRead, AsyncWrite};
@ -58,27 +62,31 @@ where
}
/// Future, returned by `listener_select_proto` which selects a protocol among the ones supported.
pub struct ListenerSelectFuture<R: AsyncRead + AsyncWrite, I, X>
pub struct ListenerSelectFuture<R, I, X>
where
for<'a> &'a I: IntoIterator<Item = X>
R: AsyncRead + AsyncWrite,
for<'a> &'a I: IntoIterator<Item = X>,
X: AsRef<[u8]>
{
inner: ListenerSelectState<R, I, X>
}
enum ListenerSelectState<R: AsyncRead + AsyncWrite, I, X>
enum ListenerSelectState<R, I, X>
where
for<'a> &'a I: IntoIterator<Item = X>
R: AsyncRead + AsyncWrite,
for<'a> &'a I: IntoIterator<Item = X>,
X: AsRef<[u8]>
{
AwaitListener {
listener_fut: ListenerFuture<R>,
listener_fut: ListenerFuture<R, X>,
protocols: I
},
Incoming {
stream: StreamFuture<Listener<R>>,
stream: StreamFuture<Listener<R, X>>,
protocols: I
},
Outgoing {
sender: sink::Send<Listener<R>>,
sender: sink::Send<Listener<R, X>>,
protocols: I,
outcome: Option<X>
},
@ -87,9 +95,9 @@ where
impl<R, I, X> Future for ListenerSelectFuture<R, I, X>
where
for<'a> &'a I: IntoIterator<Item = X>,
R: AsyncRead + AsyncWrite,
X: AsRef<[u8]>
for<'a> &'a I: IntoIterator<Item = X>,
X: AsRef<[u8]> + Clone
{
type Item = (X, R, I);
type Error = ProtocolChoiceError;
@ -119,10 +127,12 @@ where
};
match msg {
Some(DialerToListenerMessage::ProtocolsListRequest) => {
let msg = ListenerToDialerMessage::ProtocolsListResponse {
list: protocols.into_iter().map(|x| Bytes::from(x.as_ref())).collect(),
};
trace!("protocols list response: {:?}", msg);
trace!("protocols list response: {:?}", protocols
.into_iter()
.map(|p| p.as_ref().into())
.collect::<Vec<Vec<u8>>>());
let list = protocols.into_iter().collect();
let msg = ListenerToDialerMessage::ProtocolsListResponse { list };
let sender = listener.send(msg);
self.inner = ListenerSelectState::Outgoing {
sender,
@ -135,12 +145,14 @@ where
let mut send_back = ListenerToDialerMessage::NotAvailable;
for supported in &protocols {
if name.as_ref() == supported.as_ref() {
send_back = ListenerToDialerMessage::ProtocolAck {name: name.clone()};
send_back = ListenerToDialerMessage::ProtocolAck {
name: supported.clone()
};
outcome = Some(supported);
break;
}
}
trace!("requested: {:?}, response: {:?}", name, send_back);
trace!("requested: {:?}, supported: {}", name, outcome.is_some());
let sender = listener.send(send_back);
self.inner = ListenerSelectState::Outgoing { sender, protocols, outcome }
}

View File

@ -20,34 +20,35 @@
//! Contains the `Dialer` wrapper, which allows raw communications with a listener.
use bytes::Bytes;
use futures::{prelude::*, sink, Async, AsyncSink, StartSend, try_ready};
use bytes::{BufMut, Bytes, BytesMut};
use crate::length_delimited::LengthDelimited;
use crate::protocol::DialerToListenerMessage;
use crate::protocol::ListenerToDialerMessage;
use crate::protocol::MultistreamSelectError;
use crate::protocol::MULTISTREAM_PROTOCOL_WITH_LF;
use futures::{prelude::*, sink, Async, StartSend, try_ready};
use std::io;
use tokio_codec::Encoder;
use tokio_io::{AsyncRead, AsyncWrite};
use unsigned_varint::decode;
use unsigned_varint::{decode, codec::Uvi};
/// Wraps around a `AsyncRead+AsyncWrite`. Assumes that we're on the dialer's side. Produces and
/// accepts messages.
pub struct Dialer<R> {
inner: LengthDelimited<Bytes, R>,
handshake_finished: bool,
/// Wraps around a `AsyncRead+AsyncWrite`.
/// Assumes that we're on the dialer's side. Produces and accepts messages.
pub struct Dialer<R, N> {
inner: LengthDelimited<R, MessageEncoder<N>>,
handshake_finished: bool
}
impl<R> Dialer<R>
impl<R, N> Dialer<R, N>
where
R: AsyncRead + AsyncWrite,
N: AsRef<[u8]>
{
/// Takes ownership of a socket and starts the handshake. If the handshake succeeds, the
/// future returns a `Dialer`.
pub fn new(inner: R) -> DialerFuture<R> {
let sender = LengthDelimited::new(inner);
pub fn new(inner: R) -> DialerFuture<R, N> {
let codec = MessageEncoder(std::marker::PhantomData);
let sender = LengthDelimited::new(inner, codec);
DialerFuture {
inner: sender.send(Bytes::from(MULTISTREAM_PROTOCOL_WITH_LF))
inner: sender.send(Message::Header)
}
}
@ -58,43 +59,20 @@ where
}
}
impl<R> Sink for Dialer<R>
impl<R, N> Sink for Dialer<R, N>
where
R: AsyncRead + AsyncWrite,
N: AsRef<[u8]>
{
type SinkItem = DialerToListenerMessage;
type SinkItem = DialerToListenerMessage<N>;
type SinkError = MultistreamSelectError;
#[inline]
fn start_send(&mut self, item: Self::SinkItem) -> StartSend<Self::SinkItem, Self::SinkError> {
match item {
DialerToListenerMessage::ProtocolRequest { name } => {
if !name.starts_with(b"/") {
return Err(MultistreamSelectError::WrongProtocolName);
}
let mut protocol = Bytes::from(name);
protocol.extend_from_slice(&[b'\n']);
match self.inner.start_send(protocol) {
Ok(AsyncSink::Ready) => Ok(AsyncSink::Ready),
Ok(AsyncSink::NotReady(mut protocol)) => {
let protocol_len = protocol.len();
protocol.truncate(protocol_len - 1);
Ok(AsyncSink::NotReady(
DialerToListenerMessage::ProtocolRequest { name: protocol },
))
}
Err(err) => Err(err.into()),
}
}
DialerToListenerMessage::ProtocolsListRequest => {
match self.inner.start_send(Bytes::from(&b"ls\n"[..])) {
Ok(AsyncSink::Ready) => Ok(AsyncSink::Ready),
Ok(AsyncSink::NotReady(_)) => Ok(AsyncSink::NotReady(
DialerToListenerMessage::ProtocolsListRequest,
)),
Err(err) => Err(err.into()),
}
}
match self.inner.start_send(Message::Body(item))? {
AsyncSink::NotReady(Message::Body(item)) => Ok(AsyncSink::NotReady(item)),
AsyncSink::NotReady(Message::Header) => unreachable!(),
AsyncSink::Ready => Ok(AsyncSink::Ready)
}
}
@ -109,11 +87,11 @@ where
}
}
impl<R> Stream for Dialer<R>
impl<R, N> Stream for Dialer<R, N>
where
R: AsyncRead + AsyncWrite,
R: AsyncRead + AsyncWrite
{
type Item = ListenerToDialerMessage;
type Item = ListenerToDialerMessage<Bytes>;
type Error = MultistreamSelectError;
fn poll(&mut self) -> Poll<Option<Self::Item>, Self::Error> {
@ -138,7 +116,7 @@ where
let frame_len = frame.len();
let protocol = frame.split_to(frame_len - 1);
return Ok(Async::Ready(Some(ListenerToDialerMessage::ProtocolAck {
name: protocol,
name: protocol
})));
} else if frame == b"na\n"[..] {
return Ok(Async::Ready(Some(ListenerToDialerMessage::NotAvailable)));
@ -166,12 +144,12 @@ where
}
/// Future, returned by `Dialer::new`, which send the handshake and returns the actual `Dialer`.
pub struct DialerFuture<T: AsyncWrite> {
inner: sink::Send<LengthDelimited<Bytes, T>>
pub struct DialerFuture<T: AsyncWrite, N: AsRef<[u8]>> {
inner: sink::Send<LengthDelimited<T, MessageEncoder<N>>>
}
impl<T: AsyncWrite> Future for DialerFuture<T> {
type Item = Dialer<T>;
impl<T: AsyncWrite, N: AsRef<[u8]>> Future for DialerFuture<T, N> {
type Item = Dialer<T, N>;
type Error = MultistreamSelectError;
fn poll(&mut self) -> Poll<Self::Item, Self::Error> {
@ -180,15 +158,57 @@ impl<T: AsyncWrite> Future for DialerFuture<T> {
}
}
/// tokio-codec `Encoder` handling `DialerToListenerMessage` values.
struct MessageEncoder<N>(std::marker::PhantomData<N>);
enum Message<N> {
Header,
Body(DialerToListenerMessage<N>)
}
impl<N: AsRef<[u8]>> Encoder for MessageEncoder<N> {
type Item = Message<N>;
type Error = MultistreamSelectError;
fn encode(&mut self, item: Self::Item, dest: &mut BytesMut) -> Result<(), Self::Error> {
match item {
Message::Header => {
Uvi::<usize>::default().encode(MULTISTREAM_PROTOCOL_WITH_LF.len(), dest)?;
dest.reserve(MULTISTREAM_PROTOCOL_WITH_LF.len());
dest.put(MULTISTREAM_PROTOCOL_WITH_LF);
Ok(())
}
Message::Body(DialerToListenerMessage::ProtocolRequest { name }) => {
if !name.as_ref().starts_with(b"/") {
return Err(MultistreamSelectError::WrongProtocolName)
}
let len = name.as_ref().len() + 1; // + 1 for \n
if len > std::u16::MAX as usize {
return Err(io::Error::new(io::ErrorKind::InvalidData, "name too long").into())
}
Uvi::<usize>::default().encode(len, dest)?;
dest.reserve(len);
dest.put(name.as_ref());
dest.put(&b"\n"[..]);
Ok(())
}
Message::Body(DialerToListenerMessage::ProtocolsListRequest) => {
Uvi::<usize>::default().encode(3, dest)?;
dest.reserve(3);
dest.put(&b"ls\n"[..]);
Ok(())
}
}
}
}
#[cfg(test)]
mod tests {
use crate::protocol::{Dialer, DialerToListenerMessage, MultistreamSelectError};
use tokio::runtime::current_thread::Runtime;
use tokio_tcp::{TcpListener, TcpStream};
use bytes::Bytes;
use futures::Future;
use futures::{Sink, Stream};
use crate::protocol::{Dialer, DialerToListenerMessage, MultistreamSelectError};
#[test]
fn wrong_proto_name() {
@ -205,7 +225,7 @@ mod tests {
.from_err()
.and_then(move |stream| Dialer::new(stream))
.and_then(move |dialer| {
let p = Bytes::from("invalid_name");
let p = b"invalid_name";
dialer.send(DialerToListenerMessage::ProtocolRequest { name: p })
});

View File

@ -20,33 +20,35 @@
//! Contains the `Listener` wrapper, which allows raw communications with a dialer.
use bytes::Bytes;
use futures::{Async, AsyncSink, prelude::*, sink, stream::StreamFuture};
use bytes::{BufMut, Bytes, BytesMut};
use crate::length_delimited::LengthDelimited;
use crate::protocol::DialerToListenerMessage;
use crate::protocol::ListenerToDialerMessage;
use crate::protocol::MultistreamSelectError;
use crate::protocol::MULTISTREAM_PROTOCOL_WITH_LF;
use futures::{prelude::*, sink, stream::StreamFuture};
use log::{debug, trace};
use std::mem;
use std::{io, mem};
use tokio_codec::Encoder;
use tokio_io::{AsyncRead, AsyncWrite};
use unsigned_varint::encode;
use unsigned_varint::{encode, codec::Uvi};
/// Wraps around a `AsyncRead+AsyncWrite`. Assumes that we're on the listener's side. Produces and
/// accepts messages.
pub struct Listener<R> {
inner: LengthDelimited<Bytes, R>
pub struct Listener<R, N> {
inner: LengthDelimited<R, MessageEncoder<N>>
}
impl<R> Listener<R>
impl<R, N> Listener<R, N>
where
R: AsyncRead + AsyncWrite,
N: AsRef<[u8]>
{
/// Takes ownership of a socket and starts the handshake. If the handshake succeeds, the
/// future returns a `Listener`.
pub fn new(inner: R) -> ListenerFuture<R> {
let inner = LengthDelimited::new(inner);
pub fn new(inner: R) -> ListenerFuture<R, N> {
let codec = MessageEncoder(std::marker::PhantomData);
let inner = LengthDelimited::new(inner, codec);
ListenerFuture {
inner: ListenerFutureState::Await { inner: inner.into_future() }
}
@ -60,66 +62,20 @@ where
}
}
impl<R> Sink for Listener<R>
impl<R, N> Sink for Listener<R, N>
where
R: AsyncRead + AsyncWrite,
N: AsRef<[u8]>
{
type SinkItem = ListenerToDialerMessage;
type SinkItem = ListenerToDialerMessage<N>;
type SinkError = MultistreamSelectError;
#[inline]
fn start_send(&mut self, item: Self::SinkItem) -> StartSend<Self::SinkItem, Self::SinkError> {
match item {
ListenerToDialerMessage::ProtocolAck { name } => {
if !name.starts_with(b"/") {
debug!("invalid protocol name {:?}", name);
return Err(MultistreamSelectError::WrongProtocolName);
}
let mut protocol = Bytes::from(name);
protocol.extend_from_slice(&[b'\n']);
match self.inner.start_send(protocol) {
Ok(AsyncSink::Ready) => Ok(AsyncSink::Ready),
Ok(AsyncSink::NotReady(mut protocol)) => {
let protocol_len = protocol.len();
protocol.truncate(protocol_len - 1);
Ok(AsyncSink::NotReady(ListenerToDialerMessage::ProtocolAck {
name: protocol,
}))
}
Err(err) => Err(err.into()),
}
}
ListenerToDialerMessage::NotAvailable => {
match self.inner.start_send(Bytes::from(&b"na\n"[..])) {
Ok(AsyncSink::Ready) => Ok(AsyncSink::Ready),
Ok(AsyncSink::NotReady(_)) => {
Ok(AsyncSink::NotReady(ListenerToDialerMessage::NotAvailable))
}
Err(err) => Err(err.into()),
}
}
ListenerToDialerMessage::ProtocolsListResponse { list } => {
use std::iter;
let mut buf = encode::usize_buffer();
let mut out_msg = Vec::from(encode::usize(list.len(), &mut buf));
for elem in &list {
out_msg.extend(encode::usize(elem.len() + 1, &mut buf)); // +1 for '\n'
out_msg.extend_from_slice(elem);
out_msg.extend(iter::once(b'\n'));
}
match self.inner.start_send(Bytes::from(out_msg)) {
Ok(AsyncSink::Ready) => Ok(AsyncSink::Ready),
Ok(AsyncSink::NotReady(_)) => {
let m = ListenerToDialerMessage::ProtocolsListResponse { list };
Ok(AsyncSink::NotReady(m))
}
Err(err) => Err(err.into()),
}
}
match self.inner.start_send(Message::Body(item))? {
AsyncSink::NotReady(Message::Body(item)) => Ok(AsyncSink::NotReady(item)),
AsyncSink::NotReady(Message::Header) => unreachable!(),
AsyncSink::Ready => Ok(AsyncSink::Ready)
}
}
@ -134,11 +90,11 @@ where
}
}
impl<R> Stream for Listener<R>
impl<R, N> Stream for Listener<R, N>
where
R: AsyncRead + AsyncWrite,
{
type Item = DialerToListenerMessage;
type Item = DialerToListenerMessage<Bytes>;
type Error = MultistreamSelectError;
fn poll(&mut self) -> Poll<Option<Self::Item>, Self::Error> {
@ -168,22 +124,22 @@ where
/// Future, returned by `Listener::new` which performs the handshake and returns
/// the `Listener` if successful.
pub struct ListenerFuture<T: AsyncRead + AsyncWrite> {
inner: ListenerFutureState<T>
pub struct ListenerFuture<T: AsyncRead + AsyncWrite, N: AsRef<[u8]>> {
inner: ListenerFutureState<T, N>
}
enum ListenerFutureState<T: AsyncRead + AsyncWrite> {
enum ListenerFutureState<T: AsyncRead + AsyncWrite, N: AsRef<[u8]>> {
Await {
inner: StreamFuture<LengthDelimited<Bytes, T>>
inner: StreamFuture<LengthDelimited<T, MessageEncoder<N>>>
},
Reply {
sender: sink::Send<LengthDelimited<Bytes, T>>
sender: sink::Send<LengthDelimited<T, MessageEncoder<N>>>
},
Undefined
}
impl<T: AsyncRead + AsyncWrite> Future for ListenerFuture<T> {
type Item = Listener<T>;
impl<T: AsyncRead + AsyncWrite, N: AsRef<[u8]>> Future for ListenerFuture<T, N> {
type Item = Listener<T, N>;
type Error = MultistreamSelectError;
fn poll(&mut self) -> Poll<Self::Item, Self::Error> {
@ -204,7 +160,7 @@ impl<T: AsyncRead + AsyncWrite> Future for ListenerFuture<T> {
return Err(MultistreamSelectError::FailedHandshake)
}
trace!("sending back /multistream/<version> to finish the handshake");
let sender = socket.send(Bytes::from(MULTISTREAM_PROTOCOL_WITH_LF));
let sender = socket.send(Message::Header);
self.inner = ListenerFutureState::Reply { sender }
}
ListenerFutureState::Reply { mut sender } => {
@ -223,6 +179,66 @@ impl<T: AsyncRead + AsyncWrite> Future for ListenerFuture<T> {
}
}
/// tokio-codec `Encoder` handling `ListenerToDialerMessage` values.
struct MessageEncoder<N>(std::marker::PhantomData<N>);
enum Message<N> {
Header,
Body(ListenerToDialerMessage<N>)
}
impl<N: AsRef<[u8]>> Encoder for MessageEncoder<N> {
type Item = Message<N>;
type Error = MultistreamSelectError;
fn encode(&mut self, item: Self::Item, dest: &mut BytesMut) -> Result<(), Self::Error> {
match item {
Message::Header => {
Uvi::<usize>::default().encode(MULTISTREAM_PROTOCOL_WITH_LF.len(), dest)?;
dest.reserve(MULTISTREAM_PROTOCOL_WITH_LF.len());
dest.put(MULTISTREAM_PROTOCOL_WITH_LF);
Ok(())
}
Message::Body(ListenerToDialerMessage::ProtocolAck { name }) => {
if !name.as_ref().starts_with(b"/") {
return Err(MultistreamSelectError::WrongProtocolName)
}
let len = name.as_ref().len() + 1; // + 1 for \n
if len > std::u16::MAX as usize {
return Err(io::Error::new(io::ErrorKind::InvalidData, "name too long").into())
}
Uvi::<usize>::default().encode(len, dest)?;
dest.reserve(len);
dest.put(name.as_ref());
dest.put(&b"\n"[..]);
Ok(())
}
Message::Body(ListenerToDialerMessage::ProtocolsListResponse { list }) => {
let mut buf = encode::usize_buffer();
let mut out_msg = Vec::from(encode::usize(list.len(), &mut buf));
for e in &list {
if e.as_ref().len() + 1 > std::u16::MAX as usize {
return Err(io::Error::new(io::ErrorKind::InvalidData, "name too long").into())
}
out_msg.extend(encode::usize(e.as_ref().len() + 1, &mut buf)); // +1 for '\n'
out_msg.extend_from_slice(e.as_ref());
out_msg.push(b'\n')
}
let len = encode::usize(out_msg.len(), &mut buf);
dest.reserve(len.len() + out_msg.len());
dest.put(len);
dest.put(out_msg);
Ok(())
}
Message::Body(ListenerToDialerMessage::NotAvailable) => {
Uvi::<usize>::default().encode(3, dest)?;
dest.reserve(3);
dest.put(&b"na\n"[..]);
Ok(())
}
}
}
}
#[cfg(test)]
mod tests {
@ -250,7 +266,7 @@ mod tests {
let client = TcpStream::connect(&listener_addr)
.from_err()
.and_then(move |stream| Dialer::new(stream));
.and_then(move |stream| Dialer::<_, Bytes>::new(stream));
let mut rt = Runtime::new().unwrap();
match rt.block_on(server.join(client)) {

View File

@ -20,8 +20,6 @@
//! Contains lower-level structs to handle the multistream protocol.
use bytes::Bytes;
mod dialer;
mod error;
mod listener;
@ -34,14 +32,14 @@ pub use self::listener::{Listener, ListenerFuture};
/// Message sent from the dialer to the listener.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum DialerToListenerMessage {
pub enum DialerToListenerMessage<N> {
/// The dialer wants us to use a protocol.
///
/// If this is accepted (by receiving back a `ProtocolAck`), then we immediately start
/// communicating in the new protocol.
ProtocolRequest {
/// Name of the protocol.
name: Bytes,
name: N
},
/// The dialer requested the list of protocols that the listener supports.
@ -50,10 +48,10 @@ pub enum DialerToListenerMessage {
/// Message sent from the listener to the dialer.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ListenerToDialerMessage {
pub enum ListenerToDialerMessage<N> {
/// The protocol requested by the dialer is accepted. The socket immediately starts using the
/// new protocol.
ProtocolAck { name: Bytes },
ProtocolAck { name: N },
/// The protocol requested by the dialer is not supported or available.
NotAvailable,
@ -62,6 +60,7 @@ pub enum ListenerToDialerMessage {
ProtocolsListResponse {
/// The list of protocols.
// TODO: use some sort of iterator
list: Vec<Bytes>,
list: Vec<N>,
},
}

View File

@ -22,15 +22,13 @@
#![cfg(test)]
use crate::ProtocolChoiceError;
use crate::dialer_select::{dialer_select_proto_parallel, dialer_select_proto_serial};
use crate::protocol::{Dialer, DialerToListenerMessage, Listener, ListenerToDialerMessage};
use crate::{dialer_select_proto, listener_select_proto};
use futures::prelude::*;
use tokio::runtime::current_thread::Runtime;
use tokio_tcp::{TcpListener, TcpStream};
use bytes::Bytes;
use crate::dialer_select::{dialer_select_proto_parallel, dialer_select_proto_serial};
use futures::Future;
use futures::{Sink, Stream};
use crate::protocol::{Dialer, DialerToListenerMessage, Listener, ListenerToDialerMessage};
use crate::ProtocolChoiceError;
use crate::{dialer_select_proto, listener_select_proto};
/// Holds a `Vec` and satifies the iterator requirements of `listener_select_proto`.
struct VecRefIntoIter<T>(Vec<T>);
@ -68,7 +66,7 @@ fn negotiate_with_self_succeeds() {
.from_err()
.and_then(move |stream| Dialer::new(stream))
.and_then(move |dialer| {
let p = Bytes::from("/hello/1.0.0");
let p = b"/hello/1.0.0";
dialer.send(DialerToListenerMessage::ProtocolRequest { name: p })
})
.and_then(move |dialer| dialer.into_future().map_err(|(e, _)| e))