mirror of
https://github.com/fluencelabs/rust-libp2p
synced 2025-06-21 13:51:33 +00:00
multistream-select: Less allocations. (#800)
This commit is contained in:
@ -224,6 +224,7 @@ type NameWrapIter<I> =
|
||||
std::iter::Map<I, fn(<I as Iterator>::Item) -> NameWrap<<I as Iterator>::Item>>;
|
||||
|
||||
/// Wrapper type to expose an `AsRef<[u8]>` impl for all types implementing `ProtocolName`.
|
||||
#[derive(Clone)]
|
||||
struct NameWrap<N>(N);
|
||||
|
||||
impl<N: ProtocolName> AsRef<[u8]> for NameWrap<N> {
|
||||
|
@ -91,7 +91,7 @@ impl<T: AsRef<[u8]>> ProtocolName for T {
|
||||
/// or both.
|
||||
pub trait UpgradeInfo {
|
||||
/// Opaque type representing a negotiable protocol.
|
||||
type Info: ProtocolName;
|
||||
type Info: ProtocolName + Clone;
|
||||
/// Iterator returned by `protocol_info`.
|
||||
type InfoIter: IntoIterator<Item = Self::Info>;
|
||||
|
||||
|
@ -21,9 +21,13 @@
|
||||
//! Contains the `dialer_select_proto` code, which allows selecting a protocol thanks to
|
||||
//! `multistream-select` for the dialer.
|
||||
|
||||
use bytes::Bytes;
|
||||
use futures::{future::Either, prelude::*, sink, stream::StreamFuture};
|
||||
use crate::protocol::{Dialer, DialerFuture, DialerToListenerMessage, ListenerToDialerMessage};
|
||||
use futures::{future::Either, prelude::*, stream::StreamFuture};
|
||||
use crate::protocol::{
|
||||
Dialer,
|
||||
DialerFuture,
|
||||
DialerToListenerMessage,
|
||||
ListenerToDialerMessage
|
||||
};
|
||||
use log::trace;
|
||||
use std::mem;
|
||||
use tokio_io::{AsyncRead, AsyncWrite};
|
||||
@ -44,7 +48,6 @@ pub type DialerSelectFuture<R, I> = Either<DialerSelectSeq<R, I>, DialerSelectPa
|
||||
/// remote, and the protocol name that we passed (so that you don't have to clone the name). On
|
||||
/// success, the function returns the identifier (of type `P`), plus the socket which now uses that
|
||||
/// chosen protocol.
|
||||
#[inline]
|
||||
pub fn dialer_select_proto<R, I>(inner: R, protocols: I) -> DialerSelectFuture<R, I::IntoIter>
|
||||
where
|
||||
R: AsyncRead + AsyncWrite,
|
||||
@ -64,12 +67,13 @@ where
|
||||
///
|
||||
/// Same as `dialer_select_proto`. Tries protocols one by one. The iterator doesn't need to produce
|
||||
/// match functions, because it's not needed.
|
||||
pub fn dialer_select_proto_serial<R, I>(inner: R, protocols: I,) -> DialerSelectSeq<R, I>
|
||||
pub fn dialer_select_proto_serial<R, I>(inner: R, protocols: I) -> DialerSelectSeq<R, I::IntoIter>
|
||||
where
|
||||
R: AsyncRead + AsyncWrite,
|
||||
I: Iterator,
|
||||
I: IntoIterator,
|
||||
I::Item: AsRef<[u8]>
|
||||
{
|
||||
let protocols = protocols.into_iter();
|
||||
DialerSelectSeq {
|
||||
inner: DialerSelectSeqState::AwaitDialer { dialer_fut: Dialer::new(inner), protocols }
|
||||
}
|
||||
@ -78,26 +82,37 @@ where
|
||||
|
||||
/// Future, returned by `dialer_select_proto_serial` which selects a protocol
|
||||
/// and dialer sequentially.
|
||||
pub struct DialerSelectSeq<R: AsyncRead + AsyncWrite, I: Iterator> {
|
||||
pub struct DialerSelectSeq<R, I>
|
||||
where
|
||||
R: AsyncRead + AsyncWrite,
|
||||
I: Iterator,
|
||||
I::Item: AsRef<[u8]>
|
||||
{
|
||||
inner: DialerSelectSeqState<R, I>
|
||||
}
|
||||
|
||||
enum DialerSelectSeqState<R: AsyncRead + AsyncWrite, I: Iterator> {
|
||||
enum DialerSelectSeqState<R, I>
|
||||
where
|
||||
R: AsyncRead + AsyncWrite,
|
||||
I: Iterator,
|
||||
I::Item: AsRef<[u8]>
|
||||
{
|
||||
AwaitDialer {
|
||||
dialer_fut: DialerFuture<R>,
|
||||
dialer_fut: DialerFuture<R, I::Item>,
|
||||
protocols: I
|
||||
},
|
||||
NextProtocol {
|
||||
dialer: Dialer<R>,
|
||||
dialer: Dialer<R, I::Item>,
|
||||
proto_name: I::Item,
|
||||
protocols: I
|
||||
},
|
||||
SendProtocol {
|
||||
sender: sink::Send<Dialer<R>>,
|
||||
FlushProtocol {
|
||||
dialer: Dialer<R, I::Item>,
|
||||
proto_name: I::Item,
|
||||
protocols: I
|
||||
},
|
||||
AwaitProtocol {
|
||||
stream: StreamFuture<Dialer<R>>,
|
||||
stream: StreamFuture<Dialer<R, I::Item>>,
|
||||
proto_name: I::Item,
|
||||
protocols: I
|
||||
},
|
||||
@ -106,9 +121,9 @@ enum DialerSelectSeqState<R: AsyncRead + AsyncWrite, I: Iterator> {
|
||||
|
||||
impl<R, I> Future for DialerSelectSeq<R, I>
|
||||
where
|
||||
I: Iterator,
|
||||
I::Item: AsRef<[u8]>,
|
||||
R: AsyncRead + AsyncWrite,
|
||||
I: Iterator,
|
||||
I::Item: AsRef<[u8]> + Clone
|
||||
{
|
||||
type Item = (I::Item, R);
|
||||
type Error = ProtocolChoiceError;
|
||||
@ -116,7 +131,7 @@ where
|
||||
fn poll(&mut self) -> Poll<Self::Item, Self::Error> {
|
||||
loop {
|
||||
match mem::replace(&mut self.inner, DialerSelectSeqState::Undefined) {
|
||||
DialerSelectSeqState::AwaitDialer { mut dialer_fut, protocols } => {
|
||||
DialerSelectSeqState::AwaitDialer { mut dialer_fut, mut protocols } => {
|
||||
let dialer = match dialer_fut.poll()? {
|
||||
Async::Ready(d) => d,
|
||||
Async::NotReady => {
|
||||
@ -124,42 +139,57 @@ where
|
||||
return Ok(Async::NotReady)
|
||||
}
|
||||
};
|
||||
self.inner = DialerSelectSeqState::NextProtocol { dialer, protocols }
|
||||
let proto_name = protocols.next().ok_or(ProtocolChoiceError::NoProtocolFound)?;
|
||||
self.inner = DialerSelectSeqState::NextProtocol {
|
||||
dialer,
|
||||
protocols,
|
||||
proto_name
|
||||
}
|
||||
DialerSelectSeqState::NextProtocol { dialer, mut protocols } => {
|
||||
let proto_name =
|
||||
protocols.next().ok_or(ProtocolChoiceError::NoProtocolFound)?;
|
||||
}
|
||||
DialerSelectSeqState::NextProtocol { mut dialer, protocols, proto_name } => {
|
||||
trace!("sending {:?}", proto_name.as_ref());
|
||||
let req = DialerToListenerMessage::ProtocolRequest {
|
||||
name: Bytes::from(proto_name.as_ref())
|
||||
name: proto_name.clone()
|
||||
};
|
||||
trace!("sending {:?}", req);
|
||||
let sender = dialer.send(req);
|
||||
self.inner = DialerSelectSeqState::SendProtocol {
|
||||
sender,
|
||||
match dialer.start_send(req)? {
|
||||
AsyncSink::Ready => {
|
||||
self.inner = DialerSelectSeqState::FlushProtocol {
|
||||
dialer,
|
||||
proto_name,
|
||||
protocols
|
||||
}
|
||||
}
|
||||
DialerSelectSeqState::SendProtocol { mut sender, proto_name, protocols } => {
|
||||
let dialer = match sender.poll()? {
|
||||
Async::Ready(d) => d,
|
||||
Async::NotReady => {
|
||||
self.inner = DialerSelectSeqState::SendProtocol {
|
||||
sender,
|
||||
proto_name,
|
||||
protocols
|
||||
AsyncSink::NotReady(_) => {
|
||||
self.inner = DialerSelectSeqState::NextProtocol {
|
||||
dialer,
|
||||
protocols,
|
||||
proto_name
|
||||
};
|
||||
return Ok(Async::NotReady)
|
||||
}
|
||||
};
|
||||
}
|
||||
}
|
||||
DialerSelectSeqState::FlushProtocol { mut dialer, proto_name, protocols } => {
|
||||
match dialer.poll_complete()? {
|
||||
Async::Ready(()) => {
|
||||
let stream = dialer.into_future();
|
||||
self.inner = DialerSelectSeqState::AwaitProtocol {
|
||||
stream,
|
||||
proto_name,
|
||||
protocols
|
||||
};
|
||||
}
|
||||
DialerSelectSeqState::AwaitProtocol { mut stream, proto_name, protocols } => {
|
||||
}
|
||||
Async::NotReady => {
|
||||
self.inner = DialerSelectSeqState::FlushProtocol {
|
||||
dialer,
|
||||
proto_name,
|
||||
protocols
|
||||
};
|
||||
return Ok(Async::NotReady)
|
||||
}
|
||||
}
|
||||
}
|
||||
DialerSelectSeqState::AwaitProtocol { mut stream, proto_name, mut protocols } => {
|
||||
let (m, r) = match stream.poll() {
|
||||
Ok(Async::Ready(x)) => x,
|
||||
Ok(Async::NotReady) => {
|
||||
@ -178,9 +208,15 @@ where
|
||||
if name.as_ref() == proto_name.as_ref() =>
|
||||
{
|
||||
return Ok(Async::Ready((proto_name, r.into_inner())))
|
||||
},
|
||||
}
|
||||
ListenerToDialerMessage::NotAvailable => {
|
||||
self.inner = DialerSelectSeqState::NextProtocol { dialer: r, protocols }
|
||||
let proto_name = protocols.next()
|
||||
.ok_or(ProtocolChoiceError::NoProtocolFound)?;
|
||||
self.inner = DialerSelectSeqState::NextProtocol {
|
||||
dialer: r,
|
||||
protocols,
|
||||
proto_name
|
||||
}
|
||||
}
|
||||
_ => return Err(ProtocolChoiceError::UnexpectedMessage)
|
||||
}
|
||||
@ -192,17 +228,17 @@ where
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
/// Helps selecting a protocol amongst the ones supported.
|
||||
///
|
||||
/// Same as `dialer_select_proto`. Queries the list of supported protocols from the remote, then
|
||||
/// chooses the most appropriate one.
|
||||
pub fn dialer_select_proto_parallel<R, I>(inner: R, protocols: I) -> DialerSelectPar<R, I>
|
||||
pub fn dialer_select_proto_parallel<R, I>(inner: R, protocols: I) -> DialerSelectPar<R, I::IntoIter>
|
||||
where
|
||||
I: Iterator,
|
||||
I::Item: AsRef<[u8]>,
|
||||
R: AsyncRead + AsyncWrite
|
||||
R: AsyncRead + AsyncWrite,
|
||||
I: IntoIterator,
|
||||
I::Item: AsRef<[u8]>
|
||||
{
|
||||
let protocols = protocols.into_iter();
|
||||
DialerSelectPar {
|
||||
inner: DialerSelectParState::AwaitDialer { dialer_fut: Dialer::new(inner), protocols }
|
||||
}
|
||||
@ -212,29 +248,47 @@ where
|
||||
/// Future, returned by `dialer_select_proto_parallel`, which selects a protocol and dialer in
|
||||
/// parellel, by first requesting the liste of protocols supported by the remote endpoint and
|
||||
/// then selecting the most appropriate one by applying a match predicate to the result.
|
||||
pub struct DialerSelectPar<R: AsyncRead + AsyncWrite, I: Iterator> {
|
||||
pub struct DialerSelectPar<R, I>
|
||||
where
|
||||
R: AsyncRead + AsyncWrite,
|
||||
I: Iterator,
|
||||
I::Item: AsRef<[u8]>
|
||||
{
|
||||
inner: DialerSelectParState<R, I>
|
||||
}
|
||||
|
||||
enum DialerSelectParState<R: AsyncRead + AsyncWrite, I: Iterator> {
|
||||
enum DialerSelectParState<R, I>
|
||||
where
|
||||
R: AsyncRead + AsyncWrite,
|
||||
I: Iterator,
|
||||
I::Item: AsRef<[u8]>
|
||||
{
|
||||
AwaitDialer {
|
||||
dialer_fut: DialerFuture<R>,
|
||||
dialer_fut: DialerFuture<R, I::Item>,
|
||||
protocols: I
|
||||
},
|
||||
SendRequest {
|
||||
sender: sink::Send<Dialer<R>>,
|
||||
ProtocolList {
|
||||
dialer: Dialer<R, I::Item>,
|
||||
protocols: I
|
||||
},
|
||||
AwaitResponse {
|
||||
stream: StreamFuture<Dialer<R>>,
|
||||
FlushListRequest {
|
||||
dialer: Dialer<R, I::Item>,
|
||||
protocols: I
|
||||
},
|
||||
SendProtocol {
|
||||
sender: sink::Send<Dialer<R>>,
|
||||
AwaitListResponse {
|
||||
stream: StreamFuture<Dialer<R, I::Item>>,
|
||||
protocols: I,
|
||||
},
|
||||
Protocol {
|
||||
dialer: Dialer<R, I::Item>,
|
||||
proto_name: I::Item
|
||||
},
|
||||
FlushProtocol {
|
||||
dialer: Dialer<R, I::Item>,
|
||||
proto_name: I::Item
|
||||
},
|
||||
AwaitProtocol {
|
||||
stream: StreamFuture<Dialer<R>>,
|
||||
stream: StreamFuture<Dialer<R, I::Item>>,
|
||||
proto_name: I::Item
|
||||
},
|
||||
Undefined
|
||||
@ -242,9 +296,9 @@ enum DialerSelectParState<R: AsyncRead + AsyncWrite, I: Iterator> {
|
||||
|
||||
impl<R, I> Future for DialerSelectPar<R, I>
|
||||
where
|
||||
I: Iterator,
|
||||
I::Item: AsRef<[u8]>,
|
||||
R: AsyncRead + AsyncWrite,
|
||||
I: Iterator,
|
||||
I::Item: AsRef<[u8]> + Clone
|
||||
{
|
||||
type Item = (I::Item, R);
|
||||
type Error = ProtocolChoiceError;
|
||||
@ -253,41 +307,63 @@ where
|
||||
loop {
|
||||
match mem::replace(&mut self.inner, DialerSelectParState::Undefined) {
|
||||
DialerSelectParState::AwaitDialer { mut dialer_fut, protocols } => {
|
||||
let dialer = match dialer_fut.poll()? {
|
||||
Async::Ready(d) => d,
|
||||
match dialer_fut.poll()? {
|
||||
Async::Ready(dialer) => {
|
||||
self.inner = DialerSelectParState::ProtocolList { dialer, protocols }
|
||||
}
|
||||
Async::NotReady => {
|
||||
self.inner = DialerSelectParState::AwaitDialer { dialer_fut, protocols };
|
||||
return Ok(Async::NotReady)
|
||||
}
|
||||
};
|
||||
trace!("requesting protocols list");
|
||||
let sender = dialer.send(DialerToListenerMessage::ProtocolsListRequest);
|
||||
self.inner = DialerSelectParState::SendRequest { sender, protocols };
|
||||
}
|
||||
DialerSelectParState::SendRequest { mut sender, protocols } => {
|
||||
let dialer = match sender.poll()? {
|
||||
Async::Ready(d) => d,
|
||||
Async::NotReady => {
|
||||
self.inner = DialerSelectParState::SendRequest { sender, protocols };
|
||||
}
|
||||
DialerSelectParState::ProtocolList { mut dialer, protocols } => {
|
||||
trace!("requesting protocols list");
|
||||
match dialer.start_send(DialerToListenerMessage::ProtocolsListRequest)? {
|
||||
AsyncSink::Ready => {
|
||||
self.inner = DialerSelectParState::FlushListRequest {
|
||||
dialer,
|
||||
protocols
|
||||
}
|
||||
}
|
||||
AsyncSink::NotReady(_) => {
|
||||
self.inner = DialerSelectParState::ProtocolList { dialer, protocols };
|
||||
return Ok(Async::NotReady)
|
||||
}
|
||||
};
|
||||
let stream = dialer.into_future();
|
||||
self.inner = DialerSelectParState::AwaitResponse { stream, protocols };
|
||||
}
|
||||
DialerSelectParState::AwaitResponse { mut stream, protocols } => {
|
||||
let (m, d) = match stream.poll() {
|
||||
}
|
||||
DialerSelectParState::FlushListRequest { mut dialer, protocols } => {
|
||||
match dialer.poll_complete()? {
|
||||
Async::Ready(()) => {
|
||||
self.inner = DialerSelectParState::AwaitListResponse {
|
||||
stream: dialer.into_future(),
|
||||
protocols
|
||||
}
|
||||
}
|
||||
Async::NotReady => {
|
||||
self.inner = DialerSelectParState::FlushListRequest {
|
||||
dialer,
|
||||
protocols
|
||||
};
|
||||
return Ok(Async::NotReady)
|
||||
}
|
||||
}
|
||||
}
|
||||
DialerSelectParState::AwaitListResponse { mut stream, protocols } => {
|
||||
let (resp, dialer) = match stream.poll() {
|
||||
Ok(Async::Ready(x)) => x,
|
||||
Ok(Async::NotReady) => {
|
||||
self.inner = DialerSelectParState::AwaitResponse { stream, protocols };
|
||||
self.inner = DialerSelectParState::AwaitListResponse { stream, protocols };
|
||||
return Ok(Async::NotReady)
|
||||
}
|
||||
Err((e, _)) => return Err(ProtocolChoiceError::from(e))
|
||||
};
|
||||
trace!("protocols list response: {:?}", m);
|
||||
let list = match m {
|
||||
Some(ListenerToDialerMessage::ProtocolsListResponse { list }) => list,
|
||||
_ => return Err(ProtocolChoiceError::UnexpectedMessage),
|
||||
trace!("protocols list response: {:?}", resp);
|
||||
let list =
|
||||
if let Some(ListenerToDialerMessage::ProtocolsListResponse { list }) = resp {
|
||||
list
|
||||
} else {
|
||||
return Err(ProtocolChoiceError::UnexpectedMessage)
|
||||
};
|
||||
let mut found = None;
|
||||
for local_name in protocols {
|
||||
@ -302,47 +378,52 @@ where
|
||||
}
|
||||
}
|
||||
let proto_name = found.ok_or(ProtocolChoiceError::NoProtocolFound)?;
|
||||
trace!("sending {:?}", proto_name.as_ref());
|
||||
let sender = d.send(DialerToListenerMessage::ProtocolRequest {
|
||||
name: Bytes::from(proto_name.as_ref())
|
||||
});
|
||||
self.inner = DialerSelectParState::SendProtocol { sender, proto_name };
|
||||
self.inner = DialerSelectParState::Protocol { dialer, proto_name }
|
||||
}
|
||||
DialerSelectParState::SendProtocol { mut sender, proto_name } => {
|
||||
let dialer = match sender.poll()? {
|
||||
Async::Ready(d) => d,
|
||||
Async::NotReady => {
|
||||
self.inner = DialerSelectParState::SendProtocol {
|
||||
sender,
|
||||
proto_name
|
||||
DialerSelectParState::Protocol { mut dialer, proto_name } => {
|
||||
trace!("requesting protocol: {:?}", proto_name.as_ref());
|
||||
let req = DialerToListenerMessage::ProtocolRequest {
|
||||
name: proto_name.clone()
|
||||
};
|
||||
match dialer.start_send(req)? {
|
||||
AsyncSink::Ready => {
|
||||
self.inner = DialerSelectParState::FlushProtocol { dialer, proto_name }
|
||||
}
|
||||
AsyncSink::NotReady(_) => {
|
||||
self.inner = DialerSelectParState::Protocol { dialer, proto_name };
|
||||
return Ok(Async::NotReady)
|
||||
}
|
||||
};
|
||||
let stream = dialer.into_future();
|
||||
}
|
||||
}
|
||||
DialerSelectParState::FlushProtocol { mut dialer, proto_name } => {
|
||||
match dialer.poll_complete()? {
|
||||
Async::Ready(()) => {
|
||||
self.inner = DialerSelectParState::AwaitProtocol {
|
||||
stream,
|
||||
stream: dialer.into_future(),
|
||||
proto_name
|
||||
};
|
||||
}
|
||||
}
|
||||
Async::NotReady => {
|
||||
self.inner = DialerSelectParState::FlushProtocol { dialer, proto_name };
|
||||
return Ok(Async::NotReady)
|
||||
}
|
||||
}
|
||||
}
|
||||
DialerSelectParState::AwaitProtocol { mut stream, proto_name } => {
|
||||
let (m, r) = match stream.poll() {
|
||||
let (resp, dialer) = match stream.poll() {
|
||||
Ok(Async::Ready(x)) => x,
|
||||
Ok(Async::NotReady) => {
|
||||
self.inner = DialerSelectParState::AwaitProtocol {
|
||||
stream,
|
||||
proto_name
|
||||
};
|
||||
self.inner = DialerSelectParState::AwaitProtocol { stream, proto_name };
|
||||
return Ok(Async::NotReady)
|
||||
}
|
||||
Err((e, _)) => return Err(ProtocolChoiceError::from(e))
|
||||
};
|
||||
trace!("received {:?}", m);
|
||||
match m {
|
||||
trace!("received {:?}", resp);
|
||||
match resp {
|
||||
Some(ListenerToDialerMessage::ProtocolAck { ref name })
|
||||
if name.as_ref() == proto_name.as_ref() =>
|
||||
{
|
||||
return Ok(Async::Ready((proto_name, r.into_inner())))
|
||||
return Ok(Async::Ready((proto_name, dialer.into_inner())))
|
||||
}
|
||||
_ => return Err(ProtocolChoiceError::UnexpectedMessage)
|
||||
}
|
||||
@ -354,4 +435,3 @@ where
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
@ -18,21 +18,22 @@
|
||||
// FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
|
||||
// DEALINGS IN THE SOFTWARE.
|
||||
|
||||
use bytes::Bytes;
|
||||
use futures::{Async, Poll, Sink, StartSend, Stream};
|
||||
use smallvec::SmallVec;
|
||||
use std::{io::{Error as IoError, ErrorKind as IoErrorKind}, marker::PhantomData, u16};
|
||||
use tokio_codec::FramedWrite;
|
||||
use std::{io, u16};
|
||||
use tokio_codec::{Encoder, FramedWrite};
|
||||
use tokio_io::{AsyncRead, AsyncWrite};
|
||||
use unsigned_varint::codec::UviBytes;
|
||||
use unsigned_varint::decode;
|
||||
|
||||
/// `Stream` and `Sink` wrapping some `AsyncRead + AsyncWrite` object to read
|
||||
/// and write unsigned-varint prefixed frames.
|
||||
///
|
||||
/// We purposely only support a frame length of under 64kiB. Frames mostly consist
|
||||
/// in a short protocol name, which is highly unlikely to be more than 64kiB long.
|
||||
pub struct LengthDelimited<I, S> {
|
||||
pub struct LengthDelimited<R, C> {
|
||||
// The inner socket where data is pulled from.
|
||||
inner: FramedWrite<S, UviBytes>,
|
||||
inner: FramedWrite<R, C>,
|
||||
// Intermediary buffer where we put either the length of the next frame of data, or the frame
|
||||
// of data itself before it is returned.
|
||||
// Must always contain enough space to read data from `inner`.
|
||||
@ -40,8 +41,7 @@ pub struct LengthDelimited<I, S> {
|
||||
// Number of bytes within `internal_buffer` that contain valid data.
|
||||
internal_buffer_pos: usize,
|
||||
// State of the decoder.
|
||||
state: State,
|
||||
marker: PhantomData<I>,
|
||||
state: State
|
||||
}
|
||||
|
||||
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
|
||||
@ -52,24 +52,21 @@ enum State {
|
||||
ReadingData { frame_len: u16 },
|
||||
}
|
||||
|
||||
impl<I, S> LengthDelimited<I, S>
|
||||
impl<R, C> LengthDelimited<R, C>
|
||||
where
|
||||
S: AsyncWrite
|
||||
R: AsyncWrite,
|
||||
C: Encoder
|
||||
{
|
||||
pub fn new(inner: S) -> LengthDelimited<I, S> {
|
||||
let mut encoder = UviBytes::default();
|
||||
encoder.set_max_len(usize::from(u16::MAX));
|
||||
|
||||
pub fn new(inner: R, codec: C) -> LengthDelimited<R, C> {
|
||||
LengthDelimited {
|
||||
inner: FramedWrite::new(inner, encoder),
|
||||
inner: FramedWrite::new(inner, codec),
|
||||
internal_buffer: {
|
||||
let mut v = SmallVec::new();
|
||||
v.push(0);
|
||||
v
|
||||
},
|
||||
internal_buffer_pos: 0,
|
||||
state: State::ReadingLength,
|
||||
marker: PhantomData,
|
||||
state: State::ReadingLength
|
||||
}
|
||||
}
|
||||
|
||||
@ -85,20 +82,19 @@ where
|
||||
/// the modifiers provided by the `futures` crate) will always leave the object in a state in
|
||||
/// which `into_inner()` will not panic.
|
||||
#[inline]
|
||||
pub fn into_inner(self) -> S {
|
||||
pub fn into_inner(self) -> R {
|
||||
assert_eq!(self.state, State::ReadingLength);
|
||||
assert_eq!(self.internal_buffer_pos, 0);
|
||||
self.inner.into_inner()
|
||||
}
|
||||
}
|
||||
|
||||
impl<I, S> Stream for LengthDelimited<I, S>
|
||||
impl<R, C> Stream for LengthDelimited<R, C>
|
||||
where
|
||||
S: AsyncRead,
|
||||
I: for<'r> From<&'r [u8]>,
|
||||
R: AsyncRead
|
||||
{
|
||||
type Item = I;
|
||||
type Error = IoError;
|
||||
type Item = Bytes;
|
||||
type Error = io::Error;
|
||||
|
||||
fn poll(&mut self) -> Poll<Option<Self::Item>, Self::Error> {
|
||||
loop {
|
||||
@ -107,23 +103,21 @@ where
|
||||
|
||||
match self.state {
|
||||
State::ReadingLength => {
|
||||
match self.inner
|
||||
.get_mut()
|
||||
.read(&mut self.internal_buffer[self.internal_buffer_pos..])
|
||||
{
|
||||
let slice = &mut self.internal_buffer[self.internal_buffer_pos..];
|
||||
match self.inner.get_mut().read(slice) {
|
||||
Ok(0) => {
|
||||
// EOF
|
||||
if self.internal_buffer_pos == 0 {
|
||||
return Ok(Async::Ready(None));
|
||||
} else {
|
||||
return Err(IoError::new(IoErrorKind::BrokenPipe, "unexpected eof"));
|
||||
return Err(io::ErrorKind::UnexpectedEof.into());
|
||||
}
|
||||
}
|
||||
Ok(n) => {
|
||||
debug_assert_eq!(n, 1);
|
||||
self.internal_buffer_pos += n;
|
||||
}
|
||||
Err(ref err) if err.kind() == IoErrorKind::WouldBlock => {
|
||||
Err(ref err) if err.kind() == io::ErrorKind::WouldBlock => {
|
||||
return Ok(Async::NotReady);
|
||||
}
|
||||
Err(err) => {
|
||||
@ -136,7 +130,10 @@ where
|
||||
if (*self.internal_buffer.last().unwrap_or(&0) & 0x80) == 0 {
|
||||
// End of length prefix. Most of the time we will switch to reading data,
|
||||
// but we need to handle a few corner cases first.
|
||||
let frame_len = decode_length_prefix(&self.internal_buffer);
|
||||
let (frame_len, _) = decode::u16(&self.internal_buffer).map_err(|e| {
|
||||
log::debug!("invalid length prefix: {}", e);
|
||||
io::Error::new(io::ErrorKind::InvalidData, "invalid length prefix")
|
||||
})?;
|
||||
|
||||
if frame_len >= 1 {
|
||||
self.state = State::ReadingData { frame_len };
|
||||
@ -154,33 +151,22 @@ where
|
||||
}
|
||||
} else if self.internal_buffer_pos >= 2 {
|
||||
// Length prefix is too long. See module doc for info about max frame len.
|
||||
return Err(IoError::new(
|
||||
IoErrorKind::InvalidData,
|
||||
"frame length too long",
|
||||
));
|
||||
return Err(io::Error::new(io::ErrorKind::InvalidData, "frame length too long"));
|
||||
} else {
|
||||
// Prepare for next read.
|
||||
self.internal_buffer.push(0);
|
||||
}
|
||||
}
|
||||
|
||||
State::ReadingData { frame_len } => {
|
||||
match self.inner
|
||||
.get_mut()
|
||||
.read(&mut self.internal_buffer[self.internal_buffer_pos..])
|
||||
{
|
||||
Ok(0) => {
|
||||
return Err(IoError::new(IoErrorKind::BrokenPipe, "unexpected eof"));
|
||||
}
|
||||
let slice = &mut self.internal_buffer[self.internal_buffer_pos..];
|
||||
match self.inner.get_mut().read(slice) {
|
||||
Ok(0) => return Err(io::ErrorKind::UnexpectedEof.into()),
|
||||
Ok(n) => self.internal_buffer_pos += n,
|
||||
Err(ref err) if err.kind() == IoErrorKind::WouldBlock => {
|
||||
return Ok(Async::NotReady);
|
||||
}
|
||||
Err(err) => {
|
||||
return Err(err);
|
||||
Err(ref err) if err.kind() == io::ErrorKind::WouldBlock => {
|
||||
return Ok(Async::NotReady)
|
||||
}
|
||||
Err(err) => return Err(err)
|
||||
};
|
||||
|
||||
if self.internal_buffer_pos >= frame_len as usize {
|
||||
// Finished reading the frame of data.
|
||||
self.state = State::ReadingLength;
|
||||
@ -196,12 +182,13 @@ where
|
||||
}
|
||||
}
|
||||
|
||||
impl<I, S> Sink for LengthDelimited<I, S>
|
||||
impl<R, C> Sink for LengthDelimited<R, C>
|
||||
where
|
||||
S: AsyncWrite
|
||||
R: AsyncWrite,
|
||||
C: Encoder
|
||||
{
|
||||
type SinkItem = <FramedWrite<S, UviBytes> as Sink>::SinkItem;
|
||||
type SinkError = <FramedWrite<S, UviBytes> as Sink>::SinkError;
|
||||
type SinkItem = <FramedWrite<R, C> as Sink>::SinkItem;
|
||||
type SinkError = <FramedWrite<R, C> as Sink>::SinkError;
|
||||
|
||||
#[inline]
|
||||
fn start_send(&mut self, item: Self::SinkItem) -> StartSend<Self::SinkItem, Self::SinkError> {
|
||||
@ -219,33 +206,17 @@ where
|
||||
}
|
||||
}
|
||||
|
||||
fn decode_length_prefix(buf: &[u8]) -> u16 {
|
||||
debug_assert!(buf.len() <= 2);
|
||||
|
||||
let mut sum = 0u16;
|
||||
|
||||
for &byte in buf.iter().rev() {
|
||||
let byte = byte & 0x7f;
|
||||
sum <<= 7;
|
||||
debug_assert!(sum.checked_add(u16::from(byte)).is_some());
|
||||
sum += u16::from(byte);
|
||||
}
|
||||
|
||||
sum
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use futures::{Future, Stream};
|
||||
use crate::length_delimited::LengthDelimited;
|
||||
use std::io::Cursor;
|
||||
use std::io::ErrorKind;
|
||||
use std::io::{Cursor, ErrorKind};
|
||||
use unsigned_varint::codec::UviBytes;
|
||||
|
||||
#[test]
|
||||
fn basic_read() {
|
||||
let data = vec![6, 9, 8, 7, 6, 5, 4];
|
||||
let framed = LengthDelimited::<Vec<u8>, _>::new(Cursor::new(data));
|
||||
|
||||
let framed = LengthDelimited::new(Cursor::new(data), UviBytes::<Vec<_>>::default());
|
||||
let recved = framed.collect().wait().unwrap();
|
||||
assert_eq!(recved, vec![vec![9, 8, 7, 6, 5, 4]]);
|
||||
}
|
||||
@ -253,8 +224,7 @@ mod tests {
|
||||
#[test]
|
||||
fn basic_read_two() {
|
||||
let data = vec![6, 9, 8, 7, 6, 5, 4, 3, 9, 8, 7];
|
||||
let framed = LengthDelimited::<Vec<u8>, _>::new(Cursor::new(data));
|
||||
|
||||
let framed = LengthDelimited::new(Cursor::new(data), UviBytes::<Vec<_>>::default());
|
||||
let recved = framed.collect().wait().unwrap();
|
||||
assert_eq!(recved, vec![vec![9, 8, 7, 6, 5, 4], vec![9, 8, 7]]);
|
||||
}
|
||||
@ -266,8 +236,7 @@ mod tests {
|
||||
let frame = (0..len).map(|n| (n & 0xff) as u8).collect::<Vec<_>>();
|
||||
let mut data = vec![(len & 0x7f) as u8 | 0x80, (len >> 7) as u8];
|
||||
data.extend(frame.clone().into_iter());
|
||||
let framed = LengthDelimited::<Vec<u8>, _>::new(Cursor::new(data));
|
||||
|
||||
let framed = LengthDelimited::new(Cursor::new(data), UviBytes::<Vec<_>>::default());
|
||||
let recved = framed
|
||||
.into_future()
|
||||
.map(|(m, _)| m)
|
||||
@ -281,24 +250,24 @@ mod tests {
|
||||
fn packet_len_too_long() {
|
||||
let mut data = vec![0x81, 0x81, 0x1];
|
||||
data.extend((0..16513).map(|_| 0));
|
||||
let framed = LengthDelimited::<Vec<u8>, _>::new(Cursor::new(data));
|
||||
|
||||
let framed = LengthDelimited::new(Cursor::new(data), UviBytes::<Vec<_>>::default());
|
||||
let recved = framed
|
||||
.into_future()
|
||||
.map(|(m, _)| m)
|
||||
.map_err(|(err, _)| err)
|
||||
.wait();
|
||||
match recved {
|
||||
Err(io_err) => assert_eq!(io_err.kind(), ErrorKind::InvalidData),
|
||||
_ => panic!(),
|
||||
|
||||
if let Err(io_err) = recved {
|
||||
assert_eq!(io_err.kind(), ErrorKind::InvalidData)
|
||||
} else {
|
||||
panic!()
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn empty_frames() {
|
||||
let data = vec![0, 0, 6, 9, 8, 7, 6, 5, 4, 0, 3, 9, 8, 7];
|
||||
let framed = LengthDelimited::<Vec<u8>, _>::new(Cursor::new(data));
|
||||
|
||||
let framed = LengthDelimited::new(Cursor::new(data), UviBytes::<Vec<_>>::default());
|
||||
let recved = framed.collect().wait().unwrap();
|
||||
assert_eq!(
|
||||
recved,
|
||||
@ -315,36 +284,36 @@ mod tests {
|
||||
#[test]
|
||||
fn unexpected_eof_in_len() {
|
||||
let data = vec![0x89];
|
||||
let framed = LengthDelimited::<Vec<u8>, _>::new(Cursor::new(data));
|
||||
|
||||
let framed = LengthDelimited::new(Cursor::new(data), UviBytes::<Vec<_>>::default());
|
||||
let recved = framed.collect().wait();
|
||||
match recved {
|
||||
Err(io_err) => assert_eq!(io_err.kind(), ErrorKind::BrokenPipe),
|
||||
_ => panic!(),
|
||||
if let Err(io_err) = recved {
|
||||
assert_eq!(io_err.kind(), ErrorKind::UnexpectedEof)
|
||||
} else {
|
||||
panic!()
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn unexpected_eof_in_data() {
|
||||
let data = vec![5];
|
||||
let framed = LengthDelimited::<Vec<u8>, _>::new(Cursor::new(data));
|
||||
|
||||
let framed = LengthDelimited::new(Cursor::new(data), UviBytes::<Vec<_>>::default());
|
||||
let recved = framed.collect().wait();
|
||||
match recved {
|
||||
Err(io_err) => assert_eq!(io_err.kind(), ErrorKind::BrokenPipe),
|
||||
_ => panic!(),
|
||||
if let Err(io_err) = recved {
|
||||
assert_eq!(io_err.kind(), ErrorKind::UnexpectedEof)
|
||||
} else {
|
||||
panic!()
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn unexpected_eof_in_data2() {
|
||||
let data = vec![5, 9, 8, 7];
|
||||
let framed = LengthDelimited::<Vec<u8>, _>::new(Cursor::new(data));
|
||||
|
||||
let framed = LengthDelimited::new(Cursor::new(data), UviBytes::<Vec<_>>::default());
|
||||
let recved = framed.collect().wait();
|
||||
match recved {
|
||||
Err(io_err) => assert_eq!(io_err.kind(), ErrorKind::BrokenPipe),
|
||||
_ => panic!(),
|
||||
if let Err(io_err) = recved {
|
||||
assert_eq!(io_err.kind(), ErrorKind::UnexpectedEof)
|
||||
} else {
|
||||
panic!()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -21,9 +21,13 @@
|
||||
//! Contains the `listener_select_proto` code, which allows selecting a protocol thanks to
|
||||
//! `multistream-select` for the listener.
|
||||
|
||||
use bytes::Bytes;
|
||||
use futures::{prelude::*, sink, stream::StreamFuture};
|
||||
use crate::protocol::{DialerToListenerMessage, Listener, ListenerFuture, ListenerToDialerMessage};
|
||||
use crate::protocol::{
|
||||
DialerToListenerMessage,
|
||||
Listener,
|
||||
ListenerFuture,
|
||||
ListenerToDialerMessage
|
||||
};
|
||||
use log::{debug, trace};
|
||||
use std::mem;
|
||||
use tokio_io::{AsyncRead, AsyncWrite};
|
||||
@ -58,27 +62,31 @@ where
|
||||
}
|
||||
|
||||
/// Future, returned by `listener_select_proto` which selects a protocol among the ones supported.
|
||||
pub struct ListenerSelectFuture<R: AsyncRead + AsyncWrite, I, X>
|
||||
pub struct ListenerSelectFuture<R, I, X>
|
||||
where
|
||||
for<'a> &'a I: IntoIterator<Item = X>
|
||||
R: AsyncRead + AsyncWrite,
|
||||
for<'a> &'a I: IntoIterator<Item = X>,
|
||||
X: AsRef<[u8]>
|
||||
{
|
||||
inner: ListenerSelectState<R, I, X>
|
||||
}
|
||||
|
||||
enum ListenerSelectState<R: AsyncRead + AsyncWrite, I, X>
|
||||
enum ListenerSelectState<R, I, X>
|
||||
where
|
||||
for<'a> &'a I: IntoIterator<Item = X>
|
||||
R: AsyncRead + AsyncWrite,
|
||||
for<'a> &'a I: IntoIterator<Item = X>,
|
||||
X: AsRef<[u8]>
|
||||
{
|
||||
AwaitListener {
|
||||
listener_fut: ListenerFuture<R>,
|
||||
listener_fut: ListenerFuture<R, X>,
|
||||
protocols: I
|
||||
},
|
||||
Incoming {
|
||||
stream: StreamFuture<Listener<R>>,
|
||||
stream: StreamFuture<Listener<R, X>>,
|
||||
protocols: I
|
||||
},
|
||||
Outgoing {
|
||||
sender: sink::Send<Listener<R>>,
|
||||
sender: sink::Send<Listener<R, X>>,
|
||||
protocols: I,
|
||||
outcome: Option<X>
|
||||
},
|
||||
@ -87,9 +95,9 @@ where
|
||||
|
||||
impl<R, I, X> Future for ListenerSelectFuture<R, I, X>
|
||||
where
|
||||
for<'a> &'a I: IntoIterator<Item = X>,
|
||||
R: AsyncRead + AsyncWrite,
|
||||
X: AsRef<[u8]>
|
||||
for<'a> &'a I: IntoIterator<Item = X>,
|
||||
X: AsRef<[u8]> + Clone
|
||||
{
|
||||
type Item = (X, R, I);
|
||||
type Error = ProtocolChoiceError;
|
||||
@ -119,10 +127,12 @@ where
|
||||
};
|
||||
match msg {
|
||||
Some(DialerToListenerMessage::ProtocolsListRequest) => {
|
||||
let msg = ListenerToDialerMessage::ProtocolsListResponse {
|
||||
list: protocols.into_iter().map(|x| Bytes::from(x.as_ref())).collect(),
|
||||
};
|
||||
trace!("protocols list response: {:?}", msg);
|
||||
trace!("protocols list response: {:?}", protocols
|
||||
.into_iter()
|
||||
.map(|p| p.as_ref().into())
|
||||
.collect::<Vec<Vec<u8>>>());
|
||||
let list = protocols.into_iter().collect();
|
||||
let msg = ListenerToDialerMessage::ProtocolsListResponse { list };
|
||||
let sender = listener.send(msg);
|
||||
self.inner = ListenerSelectState::Outgoing {
|
||||
sender,
|
||||
@ -135,12 +145,14 @@ where
|
||||
let mut send_back = ListenerToDialerMessage::NotAvailable;
|
||||
for supported in &protocols {
|
||||
if name.as_ref() == supported.as_ref() {
|
||||
send_back = ListenerToDialerMessage::ProtocolAck {name: name.clone()};
|
||||
send_back = ListenerToDialerMessage::ProtocolAck {
|
||||
name: supported.clone()
|
||||
};
|
||||
outcome = Some(supported);
|
||||
break;
|
||||
}
|
||||
}
|
||||
trace!("requested: {:?}, response: {:?}", name, send_back);
|
||||
trace!("requested: {:?}, supported: {}", name, outcome.is_some());
|
||||
let sender = listener.send(send_back);
|
||||
self.inner = ListenerSelectState::Outgoing { sender, protocols, outcome }
|
||||
}
|
||||
|
@ -20,34 +20,35 @@
|
||||
|
||||
//! Contains the `Dialer` wrapper, which allows raw communications with a listener.
|
||||
|
||||
use bytes::Bytes;
|
||||
use futures::{prelude::*, sink, Async, AsyncSink, StartSend, try_ready};
|
||||
use bytes::{BufMut, Bytes, BytesMut};
|
||||
use crate::length_delimited::LengthDelimited;
|
||||
use crate::protocol::DialerToListenerMessage;
|
||||
use crate::protocol::ListenerToDialerMessage;
|
||||
use crate::protocol::MultistreamSelectError;
|
||||
use crate::protocol::MULTISTREAM_PROTOCOL_WITH_LF;
|
||||
use futures::{prelude::*, sink, Async, StartSend, try_ready};
|
||||
use std::io;
|
||||
use tokio_codec::Encoder;
|
||||
use tokio_io::{AsyncRead, AsyncWrite};
|
||||
use unsigned_varint::decode;
|
||||
use unsigned_varint::{decode, codec::Uvi};
|
||||
|
||||
|
||||
/// Wraps around a `AsyncRead+AsyncWrite`. Assumes that we're on the dialer's side. Produces and
|
||||
/// accepts messages.
|
||||
pub struct Dialer<R> {
|
||||
inner: LengthDelimited<Bytes, R>,
|
||||
handshake_finished: bool,
|
||||
/// Wraps around a `AsyncRead+AsyncWrite`.
|
||||
/// Assumes that we're on the dialer's side. Produces and accepts messages.
|
||||
pub struct Dialer<R, N> {
|
||||
inner: LengthDelimited<R, MessageEncoder<N>>,
|
||||
handshake_finished: bool
|
||||
}
|
||||
|
||||
impl<R> Dialer<R>
|
||||
impl<R, N> Dialer<R, N>
|
||||
where
|
||||
R: AsyncRead + AsyncWrite,
|
||||
N: AsRef<[u8]>
|
||||
{
|
||||
/// Takes ownership of a socket and starts the handshake. If the handshake succeeds, the
|
||||
/// future returns a `Dialer`.
|
||||
pub fn new(inner: R) -> DialerFuture<R> {
|
||||
let sender = LengthDelimited::new(inner);
|
||||
pub fn new(inner: R) -> DialerFuture<R, N> {
|
||||
let codec = MessageEncoder(std::marker::PhantomData);
|
||||
let sender = LengthDelimited::new(inner, codec);
|
||||
DialerFuture {
|
||||
inner: sender.send(Bytes::from(MULTISTREAM_PROTOCOL_WITH_LF))
|
||||
inner: sender.send(Message::Header)
|
||||
}
|
||||
}
|
||||
|
||||
@ -58,43 +59,20 @@ where
|
||||
}
|
||||
}
|
||||
|
||||
impl<R> Sink for Dialer<R>
|
||||
impl<R, N> Sink for Dialer<R, N>
|
||||
where
|
||||
R: AsyncRead + AsyncWrite,
|
||||
N: AsRef<[u8]>
|
||||
{
|
||||
type SinkItem = DialerToListenerMessage;
|
||||
type SinkItem = DialerToListenerMessage<N>;
|
||||
type SinkError = MultistreamSelectError;
|
||||
|
||||
#[inline]
|
||||
fn start_send(&mut self, item: Self::SinkItem) -> StartSend<Self::SinkItem, Self::SinkError> {
|
||||
match item {
|
||||
DialerToListenerMessage::ProtocolRequest { name } => {
|
||||
if !name.starts_with(b"/") {
|
||||
return Err(MultistreamSelectError::WrongProtocolName);
|
||||
}
|
||||
let mut protocol = Bytes::from(name);
|
||||
protocol.extend_from_slice(&[b'\n']);
|
||||
match self.inner.start_send(protocol) {
|
||||
Ok(AsyncSink::Ready) => Ok(AsyncSink::Ready),
|
||||
Ok(AsyncSink::NotReady(mut protocol)) => {
|
||||
let protocol_len = protocol.len();
|
||||
protocol.truncate(protocol_len - 1);
|
||||
Ok(AsyncSink::NotReady(
|
||||
DialerToListenerMessage::ProtocolRequest { name: protocol },
|
||||
))
|
||||
}
|
||||
Err(err) => Err(err.into()),
|
||||
}
|
||||
}
|
||||
|
||||
DialerToListenerMessage::ProtocolsListRequest => {
|
||||
match self.inner.start_send(Bytes::from(&b"ls\n"[..])) {
|
||||
Ok(AsyncSink::Ready) => Ok(AsyncSink::Ready),
|
||||
Ok(AsyncSink::NotReady(_)) => Ok(AsyncSink::NotReady(
|
||||
DialerToListenerMessage::ProtocolsListRequest,
|
||||
)),
|
||||
Err(err) => Err(err.into()),
|
||||
}
|
||||
}
|
||||
match self.inner.start_send(Message::Body(item))? {
|
||||
AsyncSink::NotReady(Message::Body(item)) => Ok(AsyncSink::NotReady(item)),
|
||||
AsyncSink::NotReady(Message::Header) => unreachable!(),
|
||||
AsyncSink::Ready => Ok(AsyncSink::Ready)
|
||||
}
|
||||
}
|
||||
|
||||
@ -109,11 +87,11 @@ where
|
||||
}
|
||||
}
|
||||
|
||||
impl<R> Stream for Dialer<R>
|
||||
impl<R, N> Stream for Dialer<R, N>
|
||||
where
|
||||
R: AsyncRead + AsyncWrite,
|
||||
R: AsyncRead + AsyncWrite
|
||||
{
|
||||
type Item = ListenerToDialerMessage;
|
||||
type Item = ListenerToDialerMessage<Bytes>;
|
||||
type Error = MultistreamSelectError;
|
||||
|
||||
fn poll(&mut self) -> Poll<Option<Self::Item>, Self::Error> {
|
||||
@ -138,7 +116,7 @@ where
|
||||
let frame_len = frame.len();
|
||||
let protocol = frame.split_to(frame_len - 1);
|
||||
return Ok(Async::Ready(Some(ListenerToDialerMessage::ProtocolAck {
|
||||
name: protocol,
|
||||
name: protocol
|
||||
})));
|
||||
} else if frame == b"na\n"[..] {
|
||||
return Ok(Async::Ready(Some(ListenerToDialerMessage::NotAvailable)));
|
||||
@ -166,12 +144,12 @@ where
|
||||
}
|
||||
|
||||
/// Future, returned by `Dialer::new`, which send the handshake and returns the actual `Dialer`.
|
||||
pub struct DialerFuture<T: AsyncWrite> {
|
||||
inner: sink::Send<LengthDelimited<Bytes, T>>
|
||||
pub struct DialerFuture<T: AsyncWrite, N: AsRef<[u8]>> {
|
||||
inner: sink::Send<LengthDelimited<T, MessageEncoder<N>>>
|
||||
}
|
||||
|
||||
impl<T: AsyncWrite> Future for DialerFuture<T> {
|
||||
type Item = Dialer<T>;
|
||||
impl<T: AsyncWrite, N: AsRef<[u8]>> Future for DialerFuture<T, N> {
|
||||
type Item = Dialer<T, N>;
|
||||
type Error = MultistreamSelectError;
|
||||
|
||||
fn poll(&mut self) -> Poll<Self::Item, Self::Error> {
|
||||
@ -180,15 +158,57 @@ impl<T: AsyncWrite> Future for DialerFuture<T> {
|
||||
}
|
||||
}
|
||||
|
||||
/// tokio-codec `Encoder` handling `DialerToListenerMessage` values.
|
||||
struct MessageEncoder<N>(std::marker::PhantomData<N>);
|
||||
|
||||
enum Message<N> {
|
||||
Header,
|
||||
Body(DialerToListenerMessage<N>)
|
||||
}
|
||||
|
||||
impl<N: AsRef<[u8]>> Encoder for MessageEncoder<N> {
|
||||
type Item = Message<N>;
|
||||
type Error = MultistreamSelectError;
|
||||
|
||||
fn encode(&mut self, item: Self::Item, dest: &mut BytesMut) -> Result<(), Self::Error> {
|
||||
match item {
|
||||
Message::Header => {
|
||||
Uvi::<usize>::default().encode(MULTISTREAM_PROTOCOL_WITH_LF.len(), dest)?;
|
||||
dest.reserve(MULTISTREAM_PROTOCOL_WITH_LF.len());
|
||||
dest.put(MULTISTREAM_PROTOCOL_WITH_LF);
|
||||
Ok(())
|
||||
}
|
||||
Message::Body(DialerToListenerMessage::ProtocolRequest { name }) => {
|
||||
if !name.as_ref().starts_with(b"/") {
|
||||
return Err(MultistreamSelectError::WrongProtocolName)
|
||||
}
|
||||
let len = name.as_ref().len() + 1; // + 1 for \n
|
||||
if len > std::u16::MAX as usize {
|
||||
return Err(io::Error::new(io::ErrorKind::InvalidData, "name too long").into())
|
||||
}
|
||||
Uvi::<usize>::default().encode(len, dest)?;
|
||||
dest.reserve(len);
|
||||
dest.put(name.as_ref());
|
||||
dest.put(&b"\n"[..]);
|
||||
Ok(())
|
||||
}
|
||||
Message::Body(DialerToListenerMessage::ProtocolsListRequest) => {
|
||||
Uvi::<usize>::default().encode(3, dest)?;
|
||||
dest.reserve(3);
|
||||
dest.put(&b"ls\n"[..]);
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use crate::protocol::{Dialer, DialerToListenerMessage, MultistreamSelectError};
|
||||
use tokio::runtime::current_thread::Runtime;
|
||||
use tokio_tcp::{TcpListener, TcpStream};
|
||||
use bytes::Bytes;
|
||||
use futures::Future;
|
||||
use futures::{Sink, Stream};
|
||||
use crate::protocol::{Dialer, DialerToListenerMessage, MultistreamSelectError};
|
||||
|
||||
#[test]
|
||||
fn wrong_proto_name() {
|
||||
@ -205,7 +225,7 @@ mod tests {
|
||||
.from_err()
|
||||
.and_then(move |stream| Dialer::new(stream))
|
||||
.and_then(move |dialer| {
|
||||
let p = Bytes::from("invalid_name");
|
||||
let p = b"invalid_name";
|
||||
dialer.send(DialerToListenerMessage::ProtocolRequest { name: p })
|
||||
});
|
||||
|
||||
|
@ -20,33 +20,35 @@
|
||||
|
||||
//! Contains the `Listener` wrapper, which allows raw communications with a dialer.
|
||||
|
||||
use bytes::Bytes;
|
||||
use futures::{Async, AsyncSink, prelude::*, sink, stream::StreamFuture};
|
||||
use bytes::{BufMut, Bytes, BytesMut};
|
||||
use crate::length_delimited::LengthDelimited;
|
||||
use crate::protocol::DialerToListenerMessage;
|
||||
use crate::protocol::ListenerToDialerMessage;
|
||||
use crate::protocol::MultistreamSelectError;
|
||||
use crate::protocol::MULTISTREAM_PROTOCOL_WITH_LF;
|
||||
use futures::{prelude::*, sink, stream::StreamFuture};
|
||||
use log::{debug, trace};
|
||||
use std::mem;
|
||||
use std::{io, mem};
|
||||
use tokio_codec::Encoder;
|
||||
use tokio_io::{AsyncRead, AsyncWrite};
|
||||
use unsigned_varint::encode;
|
||||
|
||||
use unsigned_varint::{encode, codec::Uvi};
|
||||
|
||||
/// Wraps around a `AsyncRead+AsyncWrite`. Assumes that we're on the listener's side. Produces and
|
||||
/// accepts messages.
|
||||
pub struct Listener<R> {
|
||||
inner: LengthDelimited<Bytes, R>
|
||||
pub struct Listener<R, N> {
|
||||
inner: LengthDelimited<R, MessageEncoder<N>>
|
||||
}
|
||||
|
||||
impl<R> Listener<R>
|
||||
impl<R, N> Listener<R, N>
|
||||
where
|
||||
R: AsyncRead + AsyncWrite,
|
||||
N: AsRef<[u8]>
|
||||
{
|
||||
/// Takes ownership of a socket and starts the handshake. If the handshake succeeds, the
|
||||
/// future returns a `Listener`.
|
||||
pub fn new(inner: R) -> ListenerFuture<R> {
|
||||
let inner = LengthDelimited::new(inner);
|
||||
pub fn new(inner: R) -> ListenerFuture<R, N> {
|
||||
let codec = MessageEncoder(std::marker::PhantomData);
|
||||
let inner = LengthDelimited::new(inner, codec);
|
||||
ListenerFuture {
|
||||
inner: ListenerFutureState::Await { inner: inner.into_future() }
|
||||
}
|
||||
@ -60,66 +62,20 @@ where
|
||||
}
|
||||
}
|
||||
|
||||
impl<R> Sink for Listener<R>
|
||||
impl<R, N> Sink for Listener<R, N>
|
||||
where
|
||||
R: AsyncRead + AsyncWrite,
|
||||
N: AsRef<[u8]>
|
||||
{
|
||||
type SinkItem = ListenerToDialerMessage;
|
||||
type SinkItem = ListenerToDialerMessage<N>;
|
||||
type SinkError = MultistreamSelectError;
|
||||
|
||||
#[inline]
|
||||
fn start_send(&mut self, item: Self::SinkItem) -> StartSend<Self::SinkItem, Self::SinkError> {
|
||||
match item {
|
||||
ListenerToDialerMessage::ProtocolAck { name } => {
|
||||
if !name.starts_with(b"/") {
|
||||
debug!("invalid protocol name {:?}", name);
|
||||
return Err(MultistreamSelectError::WrongProtocolName);
|
||||
}
|
||||
let mut protocol = Bytes::from(name);
|
||||
protocol.extend_from_slice(&[b'\n']);
|
||||
match self.inner.start_send(protocol) {
|
||||
Ok(AsyncSink::Ready) => Ok(AsyncSink::Ready),
|
||||
Ok(AsyncSink::NotReady(mut protocol)) => {
|
||||
let protocol_len = protocol.len();
|
||||
protocol.truncate(protocol_len - 1);
|
||||
Ok(AsyncSink::NotReady(ListenerToDialerMessage::ProtocolAck {
|
||||
name: protocol,
|
||||
}))
|
||||
}
|
||||
Err(err) => Err(err.into()),
|
||||
}
|
||||
}
|
||||
|
||||
ListenerToDialerMessage::NotAvailable => {
|
||||
match self.inner.start_send(Bytes::from(&b"na\n"[..])) {
|
||||
Ok(AsyncSink::Ready) => Ok(AsyncSink::Ready),
|
||||
Ok(AsyncSink::NotReady(_)) => {
|
||||
Ok(AsyncSink::NotReady(ListenerToDialerMessage::NotAvailable))
|
||||
}
|
||||
Err(err) => Err(err.into()),
|
||||
}
|
||||
}
|
||||
|
||||
ListenerToDialerMessage::ProtocolsListResponse { list } => {
|
||||
use std::iter;
|
||||
|
||||
let mut buf = encode::usize_buffer();
|
||||
let mut out_msg = Vec::from(encode::usize(list.len(), &mut buf));
|
||||
for elem in &list {
|
||||
out_msg.extend(encode::usize(elem.len() + 1, &mut buf)); // +1 for '\n'
|
||||
out_msg.extend_from_slice(elem);
|
||||
out_msg.extend(iter::once(b'\n'));
|
||||
}
|
||||
|
||||
match self.inner.start_send(Bytes::from(out_msg)) {
|
||||
Ok(AsyncSink::Ready) => Ok(AsyncSink::Ready),
|
||||
Ok(AsyncSink::NotReady(_)) => {
|
||||
let m = ListenerToDialerMessage::ProtocolsListResponse { list };
|
||||
Ok(AsyncSink::NotReady(m))
|
||||
}
|
||||
Err(err) => Err(err.into()),
|
||||
}
|
||||
}
|
||||
match self.inner.start_send(Message::Body(item))? {
|
||||
AsyncSink::NotReady(Message::Body(item)) => Ok(AsyncSink::NotReady(item)),
|
||||
AsyncSink::NotReady(Message::Header) => unreachable!(),
|
||||
AsyncSink::Ready => Ok(AsyncSink::Ready)
|
||||
}
|
||||
}
|
||||
|
||||
@ -134,11 +90,11 @@ where
|
||||
}
|
||||
}
|
||||
|
||||
impl<R> Stream for Listener<R>
|
||||
impl<R, N> Stream for Listener<R, N>
|
||||
where
|
||||
R: AsyncRead + AsyncWrite,
|
||||
{
|
||||
type Item = DialerToListenerMessage;
|
||||
type Item = DialerToListenerMessage<Bytes>;
|
||||
type Error = MultistreamSelectError;
|
||||
|
||||
fn poll(&mut self) -> Poll<Option<Self::Item>, Self::Error> {
|
||||
@ -168,22 +124,22 @@ where
|
||||
|
||||
/// Future, returned by `Listener::new` which performs the handshake and returns
|
||||
/// the `Listener` if successful.
|
||||
pub struct ListenerFuture<T: AsyncRead + AsyncWrite> {
|
||||
inner: ListenerFutureState<T>
|
||||
pub struct ListenerFuture<T: AsyncRead + AsyncWrite, N: AsRef<[u8]>> {
|
||||
inner: ListenerFutureState<T, N>
|
||||
}
|
||||
|
||||
enum ListenerFutureState<T: AsyncRead + AsyncWrite> {
|
||||
enum ListenerFutureState<T: AsyncRead + AsyncWrite, N: AsRef<[u8]>> {
|
||||
Await {
|
||||
inner: StreamFuture<LengthDelimited<Bytes, T>>
|
||||
inner: StreamFuture<LengthDelimited<T, MessageEncoder<N>>>
|
||||
},
|
||||
Reply {
|
||||
sender: sink::Send<LengthDelimited<Bytes, T>>
|
||||
sender: sink::Send<LengthDelimited<T, MessageEncoder<N>>>
|
||||
},
|
||||
Undefined
|
||||
}
|
||||
|
||||
impl<T: AsyncRead + AsyncWrite> Future for ListenerFuture<T> {
|
||||
type Item = Listener<T>;
|
||||
impl<T: AsyncRead + AsyncWrite, N: AsRef<[u8]>> Future for ListenerFuture<T, N> {
|
||||
type Item = Listener<T, N>;
|
||||
type Error = MultistreamSelectError;
|
||||
|
||||
fn poll(&mut self) -> Poll<Self::Item, Self::Error> {
|
||||
@ -204,7 +160,7 @@ impl<T: AsyncRead + AsyncWrite> Future for ListenerFuture<T> {
|
||||
return Err(MultistreamSelectError::FailedHandshake)
|
||||
}
|
||||
trace!("sending back /multistream/<version> to finish the handshake");
|
||||
let sender = socket.send(Bytes::from(MULTISTREAM_PROTOCOL_WITH_LF));
|
||||
let sender = socket.send(Message::Header);
|
||||
self.inner = ListenerFutureState::Reply { sender }
|
||||
}
|
||||
ListenerFutureState::Reply { mut sender } => {
|
||||
@ -223,6 +179,66 @@ impl<T: AsyncRead + AsyncWrite> Future for ListenerFuture<T> {
|
||||
}
|
||||
}
|
||||
|
||||
/// tokio-codec `Encoder` handling `ListenerToDialerMessage` values.
|
||||
struct MessageEncoder<N>(std::marker::PhantomData<N>);
|
||||
|
||||
enum Message<N> {
|
||||
Header,
|
||||
Body(ListenerToDialerMessage<N>)
|
||||
}
|
||||
|
||||
impl<N: AsRef<[u8]>> Encoder for MessageEncoder<N> {
|
||||
type Item = Message<N>;
|
||||
type Error = MultistreamSelectError;
|
||||
|
||||
fn encode(&mut self, item: Self::Item, dest: &mut BytesMut) -> Result<(), Self::Error> {
|
||||
match item {
|
||||
Message::Header => {
|
||||
Uvi::<usize>::default().encode(MULTISTREAM_PROTOCOL_WITH_LF.len(), dest)?;
|
||||
dest.reserve(MULTISTREAM_PROTOCOL_WITH_LF.len());
|
||||
dest.put(MULTISTREAM_PROTOCOL_WITH_LF);
|
||||
Ok(())
|
||||
}
|
||||
Message::Body(ListenerToDialerMessage::ProtocolAck { name }) => {
|
||||
if !name.as_ref().starts_with(b"/") {
|
||||
return Err(MultistreamSelectError::WrongProtocolName)
|
||||
}
|
||||
let len = name.as_ref().len() + 1; // + 1 for \n
|
||||
if len > std::u16::MAX as usize {
|
||||
return Err(io::Error::new(io::ErrorKind::InvalidData, "name too long").into())
|
||||
}
|
||||
Uvi::<usize>::default().encode(len, dest)?;
|
||||
dest.reserve(len);
|
||||
dest.put(name.as_ref());
|
||||
dest.put(&b"\n"[..]);
|
||||
Ok(())
|
||||
}
|
||||
Message::Body(ListenerToDialerMessage::ProtocolsListResponse { list }) => {
|
||||
let mut buf = encode::usize_buffer();
|
||||
let mut out_msg = Vec::from(encode::usize(list.len(), &mut buf));
|
||||
for e in &list {
|
||||
if e.as_ref().len() + 1 > std::u16::MAX as usize {
|
||||
return Err(io::Error::new(io::ErrorKind::InvalidData, "name too long").into())
|
||||
}
|
||||
out_msg.extend(encode::usize(e.as_ref().len() + 1, &mut buf)); // +1 for '\n'
|
||||
out_msg.extend_from_slice(e.as_ref());
|
||||
out_msg.push(b'\n')
|
||||
}
|
||||
let len = encode::usize(out_msg.len(), &mut buf);
|
||||
dest.reserve(len.len() + out_msg.len());
|
||||
dest.put(len);
|
||||
dest.put(out_msg);
|
||||
Ok(())
|
||||
}
|
||||
Message::Body(ListenerToDialerMessage::NotAvailable) => {
|
||||
Uvi::<usize>::default().encode(3, dest)?;
|
||||
dest.reserve(3);
|
||||
dest.put(&b"na\n"[..]);
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
@ -250,7 +266,7 @@ mod tests {
|
||||
|
||||
let client = TcpStream::connect(&listener_addr)
|
||||
.from_err()
|
||||
.and_then(move |stream| Dialer::new(stream));
|
||||
.and_then(move |stream| Dialer::<_, Bytes>::new(stream));
|
||||
|
||||
let mut rt = Runtime::new().unwrap();
|
||||
match rt.block_on(server.join(client)) {
|
||||
|
@ -20,8 +20,6 @@
|
||||
|
||||
//! Contains lower-level structs to handle the multistream protocol.
|
||||
|
||||
use bytes::Bytes;
|
||||
|
||||
mod dialer;
|
||||
mod error;
|
||||
mod listener;
|
||||
@ -34,14 +32,14 @@ pub use self::listener::{Listener, ListenerFuture};
|
||||
|
||||
/// Message sent from the dialer to the listener.
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
pub enum DialerToListenerMessage {
|
||||
pub enum DialerToListenerMessage<N> {
|
||||
/// The dialer wants us to use a protocol.
|
||||
///
|
||||
/// If this is accepted (by receiving back a `ProtocolAck`), then we immediately start
|
||||
/// communicating in the new protocol.
|
||||
ProtocolRequest {
|
||||
/// Name of the protocol.
|
||||
name: Bytes,
|
||||
name: N
|
||||
},
|
||||
|
||||
/// The dialer requested the list of protocols that the listener supports.
|
||||
@ -50,10 +48,10 @@ pub enum DialerToListenerMessage {
|
||||
|
||||
/// Message sent from the listener to the dialer.
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
pub enum ListenerToDialerMessage {
|
||||
pub enum ListenerToDialerMessage<N> {
|
||||
/// The protocol requested by the dialer is accepted. The socket immediately starts using the
|
||||
/// new protocol.
|
||||
ProtocolAck { name: Bytes },
|
||||
ProtocolAck { name: N },
|
||||
|
||||
/// The protocol requested by the dialer is not supported or available.
|
||||
NotAvailable,
|
||||
@ -62,6 +60,7 @@ pub enum ListenerToDialerMessage {
|
||||
ProtocolsListResponse {
|
||||
/// The list of protocols.
|
||||
// TODO: use some sort of iterator
|
||||
list: Vec<Bytes>,
|
||||
list: Vec<N>,
|
||||
},
|
||||
}
|
||||
|
||||
|
@ -22,15 +22,13 @@
|
||||
|
||||
#![cfg(test)]
|
||||
|
||||
use crate::ProtocolChoiceError;
|
||||
use crate::dialer_select::{dialer_select_proto_parallel, dialer_select_proto_serial};
|
||||
use crate::protocol::{Dialer, DialerToListenerMessage, Listener, ListenerToDialerMessage};
|
||||
use crate::{dialer_select_proto, listener_select_proto};
|
||||
use futures::prelude::*;
|
||||
use tokio::runtime::current_thread::Runtime;
|
||||
use tokio_tcp::{TcpListener, TcpStream};
|
||||
use bytes::Bytes;
|
||||
use crate::dialer_select::{dialer_select_proto_parallel, dialer_select_proto_serial};
|
||||
use futures::Future;
|
||||
use futures::{Sink, Stream};
|
||||
use crate::protocol::{Dialer, DialerToListenerMessage, Listener, ListenerToDialerMessage};
|
||||
use crate::ProtocolChoiceError;
|
||||
use crate::{dialer_select_proto, listener_select_proto};
|
||||
|
||||
/// Holds a `Vec` and satifies the iterator requirements of `listener_select_proto`.
|
||||
struct VecRefIntoIter<T>(Vec<T>);
|
||||
@ -68,7 +66,7 @@ fn negotiate_with_self_succeeds() {
|
||||
.from_err()
|
||||
.and_then(move |stream| Dialer::new(stream))
|
||||
.and_then(move |dialer| {
|
||||
let p = Bytes::from("/hello/1.0.0");
|
||||
let p = b"/hello/1.0.0";
|
||||
dialer.send(DialerToListenerMessage::ProtocolRequest { name: p })
|
||||
})
|
||||
.and_then(move |dialer| dialer.into_future().map_err(|(e, _)| e))
|
||||
|
Reference in New Issue
Block a user