multistream-select: use FramedWrite from tokio-codec. (#539)

This commit is contained in:
Toralf Wittner
2018-10-10 09:16:21 +02:00
committed by GitHub
parent 68632ce26b
commit fd4ae72f8c
5 changed files with 61 additions and 67 deletions

View File

@ -9,6 +9,7 @@ bytes = "0.4"
futures = { version = "0.1" } futures = { version = "0.1" }
log = "0.4" log = "0.4"
smallvec = "0.6" smallvec = "0.6"
tokio-codec = "0.1"
tokio-io = "0.1" tokio-io = "0.1"
unsigned-varint = { version = "0.2.1", features = ["codec"] } unsigned-varint = { version = "0.2.1", features = ["codec"] }

View File

@ -18,28 +18,21 @@
// FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER // FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
// DEALINGS IN THE SOFTWARE. // DEALINGS IN THE SOFTWARE.
//! Alternative implementation for `tokio_io::codec::length_delimited::FramedRead` with an
//! additional property: the `into_inner()` method is guarateed not to drop any data.
//!
//! Also has the length field length hardcoded.
//!
//! We purposely only support a frame length of under 64kiB. Frames most consist in a short
//! protocol name, which is highly unlikely to be more than 64kiB long.
use futures::{Async, Poll, Sink, StartSend, Stream}; use futures::{Async, Poll, Sink, StartSend, Stream};
use smallvec::SmallVec; use smallvec::SmallVec;
use std::io::{Error as IoError, ErrorKind as IoErrorKind}; use std::{io::{Error as IoError, ErrorKind as IoErrorKind}, marker::PhantomData, u16};
use std::marker::PhantomData; use tokio_codec::FramedWrite;
use tokio_io::AsyncRead; use tokio_io::{AsyncRead, AsyncWrite};
use unsigned_varint::codec::UviBytes;
/// Wraps around a `AsyncRead` and implements `Stream`. /// `Stream` and `Sink` wrapping some `AsyncRead + AsyncWrite` object to read
/// and write unsigned-varint prefixed frames.
/// ///
/// Also implements `Sink` if the inner object implements `Sink`, for convenience. /// We purposely only support a frame length of under 64kiB. Frames mostly consist
/// /// in a short protocol name, which is highly unlikely to be more than 64kiB long.
/// The `I` generic indicates the type of data that needs to be produced by the `Stream`. pub struct LengthDelimited<I, S> {
pub struct LengthDelimitedFramedRead<I, S> {
// The inner socket where data is pulled from. // The inner socket where data is pulled from.
inner: S, inner: FramedWrite<S, UviBytes>,
// Intermediary buffer where we put either the length of the next frame of data, or the frame // Intermediary buffer where we put either the length of the next frame of data, or the frame
// of data itself before it is returned. // of data itself before it is returned.
// Must always contain enough space to read data from `inner`. // Must always contain enough space to read data from `inner`.
@ -59,10 +52,16 @@ enum State {
ReadingData { frame_len: u16 }, ReadingData { frame_len: u16 },
} }
impl<I, S> LengthDelimitedFramedRead<I, S> { impl<I, S> LengthDelimited<I, S>
pub fn new(inner: S) -> LengthDelimitedFramedRead<I, S> { where
LengthDelimitedFramedRead { S: AsyncWrite
inner, {
pub fn new(inner: S) -> LengthDelimited<I, S> {
let mut encoder = UviBytes::default();
encoder.set_max_len(usize::from(u16::MAX));
LengthDelimited {
inner: FramedWrite::new(inner, encoder),
internal_buffer: { internal_buffer: {
let mut v = SmallVec::new(); let mut v = SmallVec::new();
v.push(0); v.push(0);
@ -74,7 +73,7 @@ impl<I, S> LengthDelimitedFramedRead<I, S> {
} }
} }
/// Destroys the `LengthDelimitedFramedRead` and returns the underlying socket. /// Destroys the `LengthDelimited` and returns the underlying socket.
/// ///
/// Contrary to its equivalent `tokio_io::codec::length_delimited::FramedRead`, this method is /// Contrary to its equivalent `tokio_io::codec::length_delimited::FramedRead`, this method is
/// guaranteed not to skip any data from the socket. /// guaranteed not to skip any data from the socket.
@ -89,11 +88,11 @@ impl<I, S> LengthDelimitedFramedRead<I, S> {
pub fn into_inner(self) -> S { pub fn into_inner(self) -> S {
assert_eq!(self.state, State::ReadingLength); assert_eq!(self.state, State::ReadingLength);
assert_eq!(self.internal_buffer_pos, 0); assert_eq!(self.internal_buffer_pos, 0);
self.inner self.inner.into_inner()
} }
} }
impl<I, S> Stream for LengthDelimitedFramedRead<I, S> impl<I, S> Stream for LengthDelimited<I, S>
where where
S: AsyncRead, S: AsyncRead,
I: for<'r> From<&'r [u8]>, I: for<'r> From<&'r [u8]>,
@ -109,6 +108,7 @@ where
match self.state { match self.state {
State::ReadingLength => { State::ReadingLength => {
match self.inner match self.inner
.get_mut()
.read(&mut self.internal_buffer[self.internal_buffer_pos..]) .read(&mut self.internal_buffer[self.internal_buffer_pos..])
{ {
Ok(0) => { Ok(0) => {
@ -166,6 +166,7 @@ where
State::ReadingData { frame_len } => { State::ReadingData { frame_len } => {
match self.inner match self.inner
.get_mut()
.read(&mut self.internal_buffer[self.internal_buffer_pos..]) .read(&mut self.internal_buffer[self.internal_buffer_pos..])
{ {
Ok(0) => { Ok(0) => {
@ -195,12 +196,12 @@ where
} }
} }
impl<I, S> Sink for LengthDelimitedFramedRead<I, S> impl<I, S> Sink for LengthDelimited<I, S>
where where
S: Sink, S: AsyncWrite
{ {
type SinkItem = S::SinkItem; type SinkItem = <FramedWrite<S, UviBytes> as Sink>::SinkItem;
type SinkError = S::SinkError; type SinkError = <FramedWrite<S, UviBytes> as Sink>::SinkError;
#[inline] #[inline]
fn start_send(&mut self, item: Self::SinkItem) -> StartSend<Self::SinkItem, Self::SinkError> { fn start_send(&mut self, item: Self::SinkItem) -> StartSend<Self::SinkItem, Self::SinkError> {
@ -236,14 +237,14 @@ fn decode_length_prefix(buf: &[u8]) -> u16 {
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use futures::{Future, Stream}; use futures::{Future, Stream};
use length_delimited::LengthDelimitedFramedRead; use length_delimited::LengthDelimited;
use std::io::Cursor; use std::io::Cursor;
use std::io::ErrorKind; use std::io::ErrorKind;
#[test] #[test]
fn basic_read() { fn basic_read() {
let data = vec![6, 9, 8, 7, 6, 5, 4]; let data = vec![6, 9, 8, 7, 6, 5, 4];
let framed = LengthDelimitedFramedRead::<Vec<u8>, _>::new(Cursor::new(data)); let framed = LengthDelimited::<Vec<u8>, _>::new(Cursor::new(data));
let recved = framed.collect().wait().unwrap(); let recved = framed.collect().wait().unwrap();
assert_eq!(recved, vec![vec![9, 8, 7, 6, 5, 4]]); assert_eq!(recved, vec![vec![9, 8, 7, 6, 5, 4]]);
@ -252,7 +253,7 @@ mod tests {
#[test] #[test]
fn basic_read_two() { fn basic_read_two() {
let data = vec![6, 9, 8, 7, 6, 5, 4, 3, 9, 8, 7]; let data = vec![6, 9, 8, 7, 6, 5, 4, 3, 9, 8, 7];
let framed = LengthDelimitedFramedRead::<Vec<u8>, _>::new(Cursor::new(data)); let framed = LengthDelimited::<Vec<u8>, _>::new(Cursor::new(data));
let recved = framed.collect().wait().unwrap(); let recved = framed.collect().wait().unwrap();
assert_eq!(recved, vec![vec![9, 8, 7, 6, 5, 4], vec![9, 8, 7]]); assert_eq!(recved, vec![vec![9, 8, 7, 6, 5, 4], vec![9, 8, 7]]);
@ -265,7 +266,7 @@ mod tests {
let frame = (0..len).map(|n| (n & 0xff) as u8).collect::<Vec<_>>(); let frame = (0..len).map(|n| (n & 0xff) as u8).collect::<Vec<_>>();
let mut data = vec![(len & 0x7f) as u8 | 0x80, (len >> 7) as u8]; let mut data = vec![(len & 0x7f) as u8 | 0x80, (len >> 7) as u8];
data.extend(frame.clone().into_iter()); data.extend(frame.clone().into_iter());
let framed = LengthDelimitedFramedRead::<Vec<u8>, _>::new(Cursor::new(data)); let framed = LengthDelimited::<Vec<u8>, _>::new(Cursor::new(data));
let recved = framed let recved = framed
.into_future() .into_future()
@ -280,7 +281,7 @@ mod tests {
fn packet_len_too_long() { fn packet_len_too_long() {
let mut data = vec![0x81, 0x81, 0x1]; let mut data = vec![0x81, 0x81, 0x1];
data.extend((0..16513).map(|_| 0)); data.extend((0..16513).map(|_| 0));
let framed = LengthDelimitedFramedRead::<Vec<u8>, _>::new(Cursor::new(data)); let framed = LengthDelimited::<Vec<u8>, _>::new(Cursor::new(data));
let recved = framed let recved = framed
.into_future() .into_future()
@ -296,7 +297,7 @@ mod tests {
#[test] #[test]
fn empty_frames() { fn empty_frames() {
let data = vec![0, 0, 6, 9, 8, 7, 6, 5, 4, 0, 3, 9, 8, 7]; let data = vec![0, 0, 6, 9, 8, 7, 6, 5, 4, 0, 3, 9, 8, 7];
let framed = LengthDelimitedFramedRead::<Vec<u8>, _>::new(Cursor::new(data)); let framed = LengthDelimited::<Vec<u8>, _>::new(Cursor::new(data));
let recved = framed.collect().wait().unwrap(); let recved = framed.collect().wait().unwrap();
assert_eq!( assert_eq!(
@ -314,7 +315,7 @@ mod tests {
#[test] #[test]
fn unexpected_eof_in_len() { fn unexpected_eof_in_len() {
let data = vec![0x89]; let data = vec![0x89];
let framed = LengthDelimitedFramedRead::<Vec<u8>, _>::new(Cursor::new(data)); let framed = LengthDelimited::<Vec<u8>, _>::new(Cursor::new(data));
let recved = framed.collect().wait(); let recved = framed.collect().wait();
match recved { match recved {
@ -326,7 +327,7 @@ mod tests {
#[test] #[test]
fn unexpected_eof_in_data() { fn unexpected_eof_in_data() {
let data = vec![5]; let data = vec![5];
let framed = LengthDelimitedFramedRead::<Vec<u8>, _>::new(Cursor::new(data)); let framed = LengthDelimited::<Vec<u8>, _>::new(Cursor::new(data));
let recved = framed.collect().wait(); let recved = framed.collect().wait();
match recved { match recved {
@ -338,7 +339,7 @@ mod tests {
#[test] #[test]
fn unexpected_eof_in_data2() { fn unexpected_eof_in_data2() {
let data = vec![5, 9, 8, 7]; let data = vec![5, 9, 8, 7];
let framed = LengthDelimitedFramedRead::<Vec<u8>, _>::new(Cursor::new(data)); let framed = LengthDelimited::<Vec<u8>, _>::new(Cursor::new(data));
let recved = framed.collect().wait(); let recved = framed.collect().wait();
match recved { match recved {
@ -347,3 +348,4 @@ mod tests {
} }
} }
} }

View File

@ -118,6 +118,7 @@ extern crate futures;
#[macro_use] #[macro_use]
extern crate log; extern crate log;
extern crate smallvec; extern crate smallvec;
extern crate tokio_codec;
extern crate tokio_io; extern crate tokio_io;
extern crate unsigned_varint; extern crate unsigned_varint;

View File

@ -20,15 +20,13 @@
//! Contains the `Dialer` wrapper, which allows raw communications with a listener. //! Contains the `Dialer` wrapper, which allows raw communications with a listener.
use bytes::{Bytes, BytesMut}; use bytes::Bytes;
use futures::{prelude::*, sink, Async, AsyncSink, StartSend}; use futures::{prelude::*, sink, Async, AsyncSink, StartSend};
use length_delimited::LengthDelimitedFramedRead; use length_delimited::LengthDelimited;
use protocol::DialerToListenerMessage; use protocol::DialerToListenerMessage;
use protocol::ListenerToDialerMessage; use protocol::ListenerToDialerMessage;
use protocol::MultistreamSelectError; use protocol::MultistreamSelectError;
use protocol::MULTISTREAM_PROTOCOL_WITH_LF; use protocol::MULTISTREAM_PROTOCOL_WITH_LF;
use tokio_io::codec::length_delimited::Builder as LengthDelimitedBuilder;
use tokio_io::codec::length_delimited::FramedWrite as LengthDelimitedFramedWrite;
use tokio_io::{AsyncRead, AsyncWrite}; use tokio_io::{AsyncRead, AsyncWrite};
use unsigned_varint::decode; use unsigned_varint::decode;
@ -36,7 +34,7 @@ use unsigned_varint::decode;
/// Wraps around a `AsyncRead+AsyncWrite`. Assumes that we're on the dialer's side. Produces and /// Wraps around a `AsyncRead+AsyncWrite`. Assumes that we're on the dialer's side. Produces and
/// accepts messages. /// accepts messages.
pub struct Dialer<R> { pub struct Dialer<R> {
inner: LengthDelimitedFramedRead<Bytes, LengthDelimitedFramedWrite<R, BytesMut>>, inner: LengthDelimited<Bytes, R>,
handshake_finished: bool, handshake_finished: bool,
} }
@ -47,17 +45,16 @@ where
/// Takes ownership of a socket and starts the handshake. If the handshake succeeds, the /// Takes ownership of a socket and starts the handshake. If the handshake succeeds, the
/// future returns a `Dialer`. /// future returns a `Dialer`.
pub fn new(inner: R) -> DialerFuture<R> { pub fn new(inner: R) -> DialerFuture<R> {
let write = LengthDelimitedBuilder::new().length_field_length(1).new_write(inner); let sender = LengthDelimited::new(inner);
let sender = LengthDelimitedFramedRead::new(write);
DialerFuture { DialerFuture {
inner: sender.send(BytesMut::from(MULTISTREAM_PROTOCOL_WITH_LF)) inner: sender.send(Bytes::from(MULTISTREAM_PROTOCOL_WITH_LF))
} }
} }
/// Grants back the socket. Typically used after a `ProtocolAck` has been received. /// Grants back the socket. Typically used after a `ProtocolAck` has been received.
#[inline] #[inline]
pub fn into_inner(self) -> R { pub fn into_inner(self) -> R {
self.inner.into_inner().into_inner() self.inner.into_inner()
} }
} }
@ -74,14 +71,13 @@ where
if !name.starts_with(b"/") { if !name.starts_with(b"/") {
return Err(MultistreamSelectError::WrongProtocolName); return Err(MultistreamSelectError::WrongProtocolName);
} }
let mut protocol = BytesMut::from(name); let mut protocol = Bytes::from(name);
protocol.extend_from_slice(&[b'\n']); protocol.extend_from_slice(&[b'\n']);
match self.inner.start_send(protocol) { match self.inner.start_send(protocol) {
Ok(AsyncSink::Ready) => Ok(AsyncSink::Ready), Ok(AsyncSink::Ready) => Ok(AsyncSink::Ready),
Ok(AsyncSink::NotReady(mut protocol)) => { Ok(AsyncSink::NotReady(mut protocol)) => {
let protocol_len = protocol.len(); let protocol_len = protocol.len();
protocol.truncate(protocol_len - 1); protocol.truncate(protocol_len - 1);
let protocol = protocol.freeze();
Ok(AsyncSink::NotReady( Ok(AsyncSink::NotReady(
DialerToListenerMessage::ProtocolRequest { name: protocol }, DialerToListenerMessage::ProtocolRequest { name: protocol },
)) ))
@ -91,7 +87,7 @@ where
} }
DialerToListenerMessage::ProtocolsListRequest => { DialerToListenerMessage::ProtocolsListRequest => {
match self.inner.start_send(BytesMut::from(&b"ls\n"[..])) { match self.inner.start_send(Bytes::from(&b"ls\n"[..])) {
Ok(AsyncSink::Ready) => Ok(AsyncSink::Ready), Ok(AsyncSink::Ready) => Ok(AsyncSink::Ready),
Ok(AsyncSink::NotReady(_)) => Ok(AsyncSink::NotReady( Ok(AsyncSink::NotReady(_)) => Ok(AsyncSink::NotReady(
DialerToListenerMessage::ProtocolsListRequest, DialerToListenerMessage::ProtocolsListRequest,
@ -171,7 +167,7 @@ where
/// Future, returned by `Dialer::new`, which send the handshake and returns the actual `Dialer`. /// Future, returned by `Dialer::new`, which send the handshake and returns the actual `Dialer`.
pub struct DialerFuture<T: AsyncWrite> { pub struct DialerFuture<T: AsyncWrite> {
inner: sink::Send<LengthDelimitedFramedRead<Bytes, LengthDelimitedFramedWrite<T, BytesMut>>> inner: sink::Send<LengthDelimited<Bytes, T>>
} }
impl<T: AsyncWrite> Future for DialerFuture<T> { impl<T: AsyncWrite> Future for DialerFuture<T> {

View File

@ -20,16 +20,14 @@
//! Contains the `Listener` wrapper, which allows raw communications with a dialer. //! Contains the `Listener` wrapper, which allows raw communications with a dialer.
use bytes::{Bytes, BytesMut}; use bytes::Bytes;
use futures::{Async, AsyncSink, prelude::*, sink, stream::StreamFuture}; use futures::{Async, AsyncSink, prelude::*, sink, stream::StreamFuture};
use length_delimited::LengthDelimitedFramedRead; use length_delimited::LengthDelimited;
use protocol::DialerToListenerMessage; use protocol::DialerToListenerMessage;
use protocol::ListenerToDialerMessage; use protocol::ListenerToDialerMessage;
use protocol::MultistreamSelectError; use protocol::MultistreamSelectError;
use protocol::MULTISTREAM_PROTOCOL_WITH_LF; use protocol::MULTISTREAM_PROTOCOL_WITH_LF;
use std::mem; use std::mem;
use tokio_io::codec::length_delimited::Builder as LengthDelimitedBuilder;
use tokio_io::codec::length_delimited::FramedWrite as LengthDelimitedFramedWrite;
use tokio_io::{AsyncRead, AsyncWrite}; use tokio_io::{AsyncRead, AsyncWrite};
use unsigned_varint::encode; use unsigned_varint::encode;
@ -37,7 +35,7 @@ use unsigned_varint::encode;
/// Wraps around a `AsyncRead+AsyncWrite`. Assumes that we're on the listener's side. Produces and /// Wraps around a `AsyncRead+AsyncWrite`. Assumes that we're on the listener's side. Produces and
/// accepts messages. /// accepts messages.
pub struct Listener<R> { pub struct Listener<R> {
inner: LengthDelimitedFramedRead<Bytes, LengthDelimitedFramedWrite<R, BytesMut>>, inner: LengthDelimited<Bytes, R>
} }
impl<R> Listener<R> impl<R> Listener<R>
@ -47,10 +45,7 @@ where
/// Takes ownership of a socket and starts the handshake. If the handshake succeeds, the /// Takes ownership of a socket and starts the handshake. If the handshake succeeds, the
/// future returns a `Listener`. /// future returns a `Listener`.
pub fn new(inner: R) -> ListenerFuture<R> { pub fn new(inner: R) -> ListenerFuture<R> {
let write = LengthDelimitedBuilder::new() let inner = LengthDelimited::new(inner);
.length_field_length(1)
.new_write(inner);
let inner = LengthDelimitedFramedRead::<Bytes, _>::new(write);
ListenerFuture { ListenerFuture {
inner: ListenerFutureState::Await { inner: inner.into_future() } inner: ListenerFutureState::Await { inner: inner.into_future() }
} }
@ -60,7 +55,7 @@ where
/// `ProtocolAck` has been sent back. /// `ProtocolAck` has been sent back.
#[inline] #[inline]
pub fn into_inner(self) -> R { pub fn into_inner(self) -> R {
self.inner.into_inner().into_inner() self.inner.into_inner()
} }
} }
@ -79,14 +74,13 @@ where
debug!("invalid protocol name {:?}", name); debug!("invalid protocol name {:?}", name);
return Err(MultistreamSelectError::WrongProtocolName); return Err(MultistreamSelectError::WrongProtocolName);
} }
let mut protocol = BytesMut::from(name); let mut protocol = Bytes::from(name);
protocol.extend_from_slice(&[b'\n']); protocol.extend_from_slice(&[b'\n']);
match self.inner.start_send(protocol) { match self.inner.start_send(protocol) {
Ok(AsyncSink::Ready) => Ok(AsyncSink::Ready), Ok(AsyncSink::Ready) => Ok(AsyncSink::Ready),
Ok(AsyncSink::NotReady(mut protocol)) => { Ok(AsyncSink::NotReady(mut protocol)) => {
let protocol_len = protocol.len(); let protocol_len = protocol.len();
protocol.truncate(protocol_len - 1); protocol.truncate(protocol_len - 1);
let protocol = protocol.freeze();
Ok(AsyncSink::NotReady(ListenerToDialerMessage::ProtocolAck { Ok(AsyncSink::NotReady(ListenerToDialerMessage::ProtocolAck {
name: protocol, name: protocol,
})) }))
@ -96,7 +90,7 @@ where
} }
ListenerToDialerMessage::NotAvailable => { ListenerToDialerMessage::NotAvailable => {
match self.inner.start_send(BytesMut::from(&b"na\n"[..])) { match self.inner.start_send(Bytes::from(&b"na\n"[..])) {
Ok(AsyncSink::Ready) => Ok(AsyncSink::Ready), Ok(AsyncSink::Ready) => Ok(AsyncSink::Ready),
Ok(AsyncSink::NotReady(_)) => { Ok(AsyncSink::NotReady(_)) => {
Ok(AsyncSink::NotReady(ListenerToDialerMessage::NotAvailable)) Ok(AsyncSink::NotReady(ListenerToDialerMessage::NotAvailable))
@ -116,7 +110,7 @@ where
out_msg.extend(iter::once(b'\n')); out_msg.extend(iter::once(b'\n'));
} }
match self.inner.start_send(BytesMut::from(out_msg)) { match self.inner.start_send(Bytes::from(out_msg)) {
Ok(AsyncSink::Ready) => Ok(AsyncSink::Ready), Ok(AsyncSink::Ready) => Ok(AsyncSink::Ready),
Ok(AsyncSink::NotReady(_)) => { Ok(AsyncSink::NotReady(_)) => {
let m = ListenerToDialerMessage::ProtocolsListResponse { list }; let m = ListenerToDialerMessage::ProtocolsListResponse { list };
@ -179,10 +173,10 @@ pub struct ListenerFuture<T: AsyncRead + AsyncWrite> {
enum ListenerFutureState<T: AsyncRead + AsyncWrite> { enum ListenerFutureState<T: AsyncRead + AsyncWrite> {
Await { Await {
inner: StreamFuture<LengthDelimitedFramedRead<Bytes, LengthDelimitedFramedWrite<T, BytesMut>>> inner: StreamFuture<LengthDelimited<Bytes, T>>
}, },
Reply { Reply {
sender: sink::Send<LengthDelimitedFramedRead<Bytes, LengthDelimitedFramedWrite<T, BytesMut>>> sender: sink::Send<LengthDelimited<Bytes, T>>
}, },
Undefined Undefined
} }
@ -209,7 +203,7 @@ impl<T: AsyncRead + AsyncWrite> Future for ListenerFuture<T> {
return Err(MultistreamSelectError::FailedHandshake) return Err(MultistreamSelectError::FailedHandshake)
} }
trace!("sending back /multistream/<version> to finish the handshake"); trace!("sending back /multistream/<version> to finish the handshake");
let sender = socket.send(BytesMut::from(MULTISTREAM_PROTOCOL_WITH_LF)); let sender = socket.send(Bytes::from(MULTISTREAM_PROTOCOL_WITH_LF));
self.inner = ListenerFutureState::Reply { sender } self.inner = ListenerFutureState::Reply { sender }
} }
ListenerFutureState::Reply { mut sender } => { ListenerFutureState::Reply { mut sender } => {