multistream-select: Less allocations. (#800)

This commit is contained in:
Toralf Wittner
2019-01-09 15:09:35 +01:00
committed by GitHub
parent aedf9c0c31
commit f1959252b7
9 changed files with 467 additions and 372 deletions

View File

@ -21,9 +21,13 @@
//! Contains the `dialer_select_proto` code, which allows selecting a protocol thanks to
//! `multistream-select` for the dialer.
use bytes::Bytes;
use futures::{future::Either, prelude::*, sink, stream::StreamFuture};
use crate::protocol::{Dialer, DialerFuture, DialerToListenerMessage, ListenerToDialerMessage};
use futures::{future::Either, prelude::*, stream::StreamFuture};
use crate::protocol::{
Dialer,
DialerFuture,
DialerToListenerMessage,
ListenerToDialerMessage
};
use log::trace;
use std::mem;
use tokio_io::{AsyncRead, AsyncWrite};
@ -44,7 +48,6 @@ pub type DialerSelectFuture<R, I> = Either<DialerSelectSeq<R, I>, DialerSelectPa
/// remote, and the protocol name that we passed (so that you don't have to clone the name). On
/// success, the function returns the identifier (of type `P`), plus the socket which now uses that
/// chosen protocol.
#[inline]
pub fn dialer_select_proto<R, I>(inner: R, protocols: I) -> DialerSelectFuture<R, I::IntoIter>
where
R: AsyncRead + AsyncWrite,
@ -64,12 +67,13 @@ where
///
/// Same as `dialer_select_proto`. Tries protocols one by one. The iterator doesn't need to produce
/// match functions, because it's not needed.
pub fn dialer_select_proto_serial<R, I>(inner: R, protocols: I,) -> DialerSelectSeq<R, I>
pub fn dialer_select_proto_serial<R, I>(inner: R, protocols: I) -> DialerSelectSeq<R, I::IntoIter>
where
R: AsyncRead + AsyncWrite,
I: Iterator,
I: IntoIterator,
I::Item: AsRef<[u8]>
{
let protocols = protocols.into_iter();
DialerSelectSeq {
inner: DialerSelectSeqState::AwaitDialer { dialer_fut: Dialer::new(inner), protocols }
}
@ -78,26 +82,37 @@ where
/// Future, returned by `dialer_select_proto_serial` which selects a protocol
/// and dialer sequentially.
pub struct DialerSelectSeq<R: AsyncRead + AsyncWrite, I: Iterator> {
pub struct DialerSelectSeq<R, I>
where
R: AsyncRead + AsyncWrite,
I: Iterator,
I::Item: AsRef<[u8]>
{
inner: DialerSelectSeqState<R, I>
}
enum DialerSelectSeqState<R: AsyncRead + AsyncWrite, I: Iterator> {
enum DialerSelectSeqState<R, I>
where
R: AsyncRead + AsyncWrite,
I: Iterator,
I::Item: AsRef<[u8]>
{
AwaitDialer {
dialer_fut: DialerFuture<R>,
dialer_fut: DialerFuture<R, I::Item>,
protocols: I
},
NextProtocol {
dialer: Dialer<R>,
dialer: Dialer<R, I::Item>,
proto_name: I::Item,
protocols: I
},
SendProtocol {
sender: sink::Send<Dialer<R>>,
FlushProtocol {
dialer: Dialer<R, I::Item>,
proto_name: I::Item,
protocols: I
},
AwaitProtocol {
stream: StreamFuture<Dialer<R>>,
stream: StreamFuture<Dialer<R, I::Item>>,
proto_name: I::Item,
protocols: I
},
@ -106,9 +121,9 @@ enum DialerSelectSeqState<R: AsyncRead + AsyncWrite, I: Iterator> {
impl<R, I> Future for DialerSelectSeq<R, I>
where
I: Iterator,
I::Item: AsRef<[u8]>,
R: AsyncRead + AsyncWrite,
I: Iterator,
I::Item: AsRef<[u8]> + Clone
{
type Item = (I::Item, R);
type Error = ProtocolChoiceError;
@ -116,7 +131,7 @@ where
fn poll(&mut self) -> Poll<Self::Item, Self::Error> {
loop {
match mem::replace(&mut self.inner, DialerSelectSeqState::Undefined) {
DialerSelectSeqState::AwaitDialer { mut dialer_fut, protocols } => {
DialerSelectSeqState::AwaitDialer { mut dialer_fut, mut protocols } => {
let dialer = match dialer_fut.poll()? {
Async::Ready(d) => d,
Async::NotReady => {
@ -124,42 +139,57 @@ where
return Ok(Async::NotReady)
}
};
self.inner = DialerSelectSeqState::NextProtocol { dialer, protocols }
}
DialerSelectSeqState::NextProtocol { dialer, mut protocols } => {
let proto_name =
protocols.next().ok_or(ProtocolChoiceError::NoProtocolFound)?;
let req = DialerToListenerMessage::ProtocolRequest {
name: Bytes::from(proto_name.as_ref())
};
trace!("sending {:?}", req);
let sender = dialer.send(req);
self.inner = DialerSelectSeqState::SendProtocol {
sender,
proto_name,
protocols
let proto_name = protocols.next().ok_or(ProtocolChoiceError::NoProtocolFound)?;
self.inner = DialerSelectSeqState::NextProtocol {
dialer,
protocols,
proto_name
}
}
DialerSelectSeqState::SendProtocol { mut sender, proto_name, protocols } => {
let dialer = match sender.poll()? {
Async::Ready(d) => d,
DialerSelectSeqState::NextProtocol { mut dialer, protocols, proto_name } => {
trace!("sending {:?}", proto_name.as_ref());
let req = DialerToListenerMessage::ProtocolRequest {
name: proto_name.clone()
};
match dialer.start_send(req)? {
AsyncSink::Ready => {
self.inner = DialerSelectSeqState::FlushProtocol {
dialer,
proto_name,
protocols
}
}
AsyncSink::NotReady(_) => {
self.inner = DialerSelectSeqState::NextProtocol {
dialer,
protocols,
proto_name
};
return Ok(Async::NotReady)
}
}
}
DialerSelectSeqState::FlushProtocol { mut dialer, proto_name, protocols } => {
match dialer.poll_complete()? {
Async::Ready(()) => {
let stream = dialer.into_future();
self.inner = DialerSelectSeqState::AwaitProtocol {
stream,
proto_name,
protocols
}
}
Async::NotReady => {
self.inner = DialerSelectSeqState::SendProtocol {
sender,
self.inner = DialerSelectSeqState::FlushProtocol {
dialer,
proto_name,
protocols
};
return Ok(Async::NotReady)
}
};
let stream = dialer.into_future();
self.inner = DialerSelectSeqState::AwaitProtocol {
stream,
proto_name,
protocols
};
}
}
DialerSelectSeqState::AwaitProtocol { mut stream, proto_name, protocols } => {
DialerSelectSeqState::AwaitProtocol { mut stream, proto_name, mut protocols } => {
let (m, r) = match stream.poll() {
Ok(Async::Ready(x)) => x,
Ok(Async::NotReady) => {
@ -178,9 +208,15 @@ where
if name.as_ref() == proto_name.as_ref() =>
{
return Ok(Async::Ready((proto_name, r.into_inner())))
},
}
ListenerToDialerMessage::NotAvailable => {
self.inner = DialerSelectSeqState::NextProtocol { dialer: r, protocols }
let proto_name = protocols.next()
.ok_or(ProtocolChoiceError::NoProtocolFound)?;
self.inner = DialerSelectSeqState::NextProtocol {
dialer: r,
protocols,
proto_name
}
}
_ => return Err(ProtocolChoiceError::UnexpectedMessage)
}
@ -192,17 +228,17 @@ where
}
}
/// Helps selecting a protocol amongst the ones supported.
///
/// Same as `dialer_select_proto`. Queries the list of supported protocols from the remote, then
/// chooses the most appropriate one.
pub fn dialer_select_proto_parallel<R, I>(inner: R, protocols: I) -> DialerSelectPar<R, I>
pub fn dialer_select_proto_parallel<R, I>(inner: R, protocols: I) -> DialerSelectPar<R, I::IntoIter>
where
I: Iterator,
I::Item: AsRef<[u8]>,
R: AsyncRead + AsyncWrite
R: AsyncRead + AsyncWrite,
I: IntoIterator,
I::Item: AsRef<[u8]>
{
let protocols = protocols.into_iter();
DialerSelectPar {
inner: DialerSelectParState::AwaitDialer { dialer_fut: Dialer::new(inner), protocols }
}
@ -212,29 +248,47 @@ where
/// Future, returned by `dialer_select_proto_parallel`, which selects a protocol and dialer in
/// parellel, by first requesting the liste of protocols supported by the remote endpoint and
/// then selecting the most appropriate one by applying a match predicate to the result.
pub struct DialerSelectPar<R: AsyncRead + AsyncWrite, I: Iterator> {
pub struct DialerSelectPar<R, I>
where
R: AsyncRead + AsyncWrite,
I: Iterator,
I::Item: AsRef<[u8]>
{
inner: DialerSelectParState<R, I>
}
enum DialerSelectParState<R: AsyncRead + AsyncWrite, I: Iterator> {
enum DialerSelectParState<R, I>
where
R: AsyncRead + AsyncWrite,
I: Iterator,
I::Item: AsRef<[u8]>
{
AwaitDialer {
dialer_fut: DialerFuture<R>,
dialer_fut: DialerFuture<R, I::Item>,
protocols: I
},
SendRequest {
sender: sink::Send<Dialer<R>>,
ProtocolList {
dialer: Dialer<R, I::Item>,
protocols: I
},
AwaitResponse {
stream: StreamFuture<Dialer<R>>,
FlushListRequest {
dialer: Dialer<R, I::Item>,
protocols: I
},
SendProtocol {
sender: sink::Send<Dialer<R>>,
AwaitListResponse {
stream: StreamFuture<Dialer<R, I::Item>>,
protocols: I,
},
Protocol {
dialer: Dialer<R, I::Item>,
proto_name: I::Item
},
FlushProtocol {
dialer: Dialer<R, I::Item>,
proto_name: I::Item
},
AwaitProtocol {
stream: StreamFuture<Dialer<R>>,
stream: StreamFuture<Dialer<R, I::Item>>,
proto_name: I::Item
},
Undefined
@ -242,9 +296,9 @@ enum DialerSelectParState<R: AsyncRead + AsyncWrite, I: Iterator> {
impl<R, I> Future for DialerSelectPar<R, I>
where
I: Iterator,
I::Item: AsRef<[u8]>,
R: AsyncRead + AsyncWrite,
I: Iterator,
I::Item: AsRef<[u8]> + Clone
{
type Item = (I::Item, R);
type Error = ProtocolChoiceError;
@ -253,42 +307,64 @@ where
loop {
match mem::replace(&mut self.inner, DialerSelectParState::Undefined) {
DialerSelectParState::AwaitDialer { mut dialer_fut, protocols } => {
let dialer = match dialer_fut.poll()? {
Async::Ready(d) => d,
match dialer_fut.poll()? {
Async::Ready(dialer) => {
self.inner = DialerSelectParState::ProtocolList { dialer, protocols }
}
Async::NotReady => {
self.inner = DialerSelectParState::AwaitDialer { dialer_fut, protocols };
return Ok(Async::NotReady)
}
};
trace!("requesting protocols list");
let sender = dialer.send(DialerToListenerMessage::ProtocolsListRequest);
self.inner = DialerSelectParState::SendRequest { sender, protocols };
}
}
DialerSelectParState::SendRequest { mut sender, protocols } => {
let dialer = match sender.poll()? {
Async::Ready(d) => d,
Async::NotReady => {
self.inner = DialerSelectParState::SendRequest { sender, protocols };
DialerSelectParState::ProtocolList { mut dialer, protocols } => {
trace!("requesting protocols list");
match dialer.start_send(DialerToListenerMessage::ProtocolsListRequest)? {
AsyncSink::Ready => {
self.inner = DialerSelectParState::FlushListRequest {
dialer,
protocols
}
}
AsyncSink::NotReady(_) => {
self.inner = DialerSelectParState::ProtocolList { dialer, protocols };
return Ok(Async::NotReady)
}
};
let stream = dialer.into_future();
self.inner = DialerSelectParState::AwaitResponse { stream, protocols };
}
}
DialerSelectParState::AwaitResponse { mut stream, protocols } => {
let (m, d) = match stream.poll() {
DialerSelectParState::FlushListRequest { mut dialer, protocols } => {
match dialer.poll_complete()? {
Async::Ready(()) => {
self.inner = DialerSelectParState::AwaitListResponse {
stream: dialer.into_future(),
protocols
}
}
Async::NotReady => {
self.inner = DialerSelectParState::FlushListRequest {
dialer,
protocols
};
return Ok(Async::NotReady)
}
}
}
DialerSelectParState::AwaitListResponse { mut stream, protocols } => {
let (resp, dialer) = match stream.poll() {
Ok(Async::Ready(x)) => x,
Ok(Async::NotReady) => {
self.inner = DialerSelectParState::AwaitResponse { stream, protocols };
self.inner = DialerSelectParState::AwaitListResponse { stream, protocols };
return Ok(Async::NotReady)
}
Err((e, _)) => return Err(ProtocolChoiceError::from(e))
};
trace!("protocols list response: {:?}", m);
let list = match m {
Some(ListenerToDialerMessage::ProtocolsListResponse { list }) => list,
_ => return Err(ProtocolChoiceError::UnexpectedMessage),
};
trace!("protocols list response: {:?}", resp);
let list =
if let Some(ListenerToDialerMessage::ProtocolsListResponse { list }) = resp {
list
} else {
return Err(ProtocolChoiceError::UnexpectedMessage)
};
let mut found = None;
for local_name in protocols {
for remote_name in &list {
@ -302,47 +378,52 @@ where
}
}
let proto_name = found.ok_or(ProtocolChoiceError::NoProtocolFound)?;
trace!("sending {:?}", proto_name.as_ref());
let sender = d.send(DialerToListenerMessage::ProtocolRequest {
name: Bytes::from(proto_name.as_ref())
});
self.inner = DialerSelectParState::SendProtocol { sender, proto_name };
self.inner = DialerSelectParState::Protocol { dialer, proto_name }
}
DialerSelectParState::SendProtocol { mut sender, proto_name } => {
let dialer = match sender.poll()? {
Async::Ready(d) => d,
Async::NotReady => {
self.inner = DialerSelectParState::SendProtocol {
sender,
proto_name
};
DialerSelectParState::Protocol { mut dialer, proto_name } => {
trace!("requesting protocol: {:?}", proto_name.as_ref());
let req = DialerToListenerMessage::ProtocolRequest {
name: proto_name.clone()
};
match dialer.start_send(req)? {
AsyncSink::Ready => {
self.inner = DialerSelectParState::FlushProtocol { dialer, proto_name }
}
AsyncSink::NotReady(_) => {
self.inner = DialerSelectParState::Protocol { dialer, proto_name };
return Ok(Async::NotReady)
}
};
let stream = dialer.into_future();
self.inner = DialerSelectParState::AwaitProtocol {
stream,
proto_name
};
}
}
DialerSelectParState::FlushProtocol { mut dialer, proto_name } => {
match dialer.poll_complete()? {
Async::Ready(()) => {
self.inner = DialerSelectParState::AwaitProtocol {
stream: dialer.into_future(),
proto_name
}
}
Async::NotReady => {
self.inner = DialerSelectParState::FlushProtocol { dialer, proto_name };
return Ok(Async::NotReady)
}
}
}
DialerSelectParState::AwaitProtocol { mut stream, proto_name } => {
let (m, r) = match stream.poll() {
let (resp, dialer) = match stream.poll() {
Ok(Async::Ready(x)) => x,
Ok(Async::NotReady) => {
self.inner = DialerSelectParState::AwaitProtocol {
stream,
proto_name
};
self.inner = DialerSelectParState::AwaitProtocol { stream, proto_name };
return Ok(Async::NotReady)
}
Err((e, _)) => return Err(ProtocolChoiceError::from(e))
};
trace!("received {:?}", m);
match m {
trace!("received {:?}", resp);
match resp {
Some(ListenerToDialerMessage::ProtocolAck { ref name })
if name.as_ref() == proto_name.as_ref() =>
{
return Ok(Async::Ready((proto_name, r.into_inner())))
return Ok(Async::Ready((proto_name, dialer.into_inner())))
}
_ => return Err(ProtocolChoiceError::UnexpectedMessage)
}
@ -354,4 +435,3 @@ where
}
}