multistream-select: use FramedWrite from tokio-codec. (#539)

This commit is contained in:
Toralf Wittner
2018-10-10 09:16:21 +02:00
committed by GitHub
parent 68632ce26b
commit fd4ae72f8c
5 changed files with 61 additions and 67 deletions

View File

@ -9,6 +9,7 @@ bytes = "0.4"
futures = { version = "0.1" }
log = "0.4"
smallvec = "0.6"
tokio-codec = "0.1"
tokio-io = "0.1"
unsigned-varint = { version = "0.2.1", features = ["codec"] }

View File

@ -18,28 +18,21 @@
// FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
// DEALINGS IN THE SOFTWARE.
//! Alternative implementation for `tokio_io::codec::length_delimited::FramedRead` with an
//! additional property: the `into_inner()` method is guarateed not to drop any data.
//!
//! Also has the length field length hardcoded.
//!
//! We purposely only support a frame length of under 64kiB. Frames most consist in a short
//! protocol name, which is highly unlikely to be more than 64kiB long.
use futures::{Async, Poll, Sink, StartSend, Stream};
use smallvec::SmallVec;
use std::io::{Error as IoError, ErrorKind as IoErrorKind};
use std::marker::PhantomData;
use tokio_io::AsyncRead;
use std::{io::{Error as IoError, ErrorKind as IoErrorKind}, marker::PhantomData, u16};
use tokio_codec::FramedWrite;
use tokio_io::{AsyncRead, AsyncWrite};
use unsigned_varint::codec::UviBytes;
/// Wraps around a `AsyncRead` and implements `Stream`.
/// `Stream` and `Sink` wrapping some `AsyncRead + AsyncWrite` object to read
/// and write unsigned-varint prefixed frames.
///
/// Also implements `Sink` if the inner object implements `Sink`, for convenience.
///
/// The `I` generic indicates the type of data that needs to be produced by the `Stream`.
pub struct LengthDelimitedFramedRead<I, S> {
/// We purposely only support a frame length of under 64kiB. Frames mostly consist
/// in a short protocol name, which is highly unlikely to be more than 64kiB long.
pub struct LengthDelimited<I, S> {
// The inner socket where data is pulled from.
inner: S,
inner: FramedWrite<S, UviBytes>,
// Intermediary buffer where we put either the length of the next frame of data, or the frame
// of data itself before it is returned.
// Must always contain enough space to read data from `inner`.
@ -59,10 +52,16 @@ enum State {
ReadingData { frame_len: u16 },
}
impl<I, S> LengthDelimitedFramedRead<I, S> {
pub fn new(inner: S) -> LengthDelimitedFramedRead<I, S> {
LengthDelimitedFramedRead {
inner,
impl<I, S> LengthDelimited<I, S>
where
S: AsyncWrite
{
pub fn new(inner: S) -> LengthDelimited<I, S> {
let mut encoder = UviBytes::default();
encoder.set_max_len(usize::from(u16::MAX));
LengthDelimited {
inner: FramedWrite::new(inner, encoder),
internal_buffer: {
let mut v = SmallVec::new();
v.push(0);
@ -74,7 +73,7 @@ impl<I, S> LengthDelimitedFramedRead<I, S> {
}
}
/// Destroys the `LengthDelimitedFramedRead` and returns the underlying socket.
/// Destroys the `LengthDelimited` and returns the underlying socket.
///
/// Contrary to its equivalent `tokio_io::codec::length_delimited::FramedRead`, this method is
/// guaranteed not to skip any data from the socket.
@ -89,11 +88,11 @@ impl<I, S> LengthDelimitedFramedRead<I, S> {
pub fn into_inner(self) -> S {
assert_eq!(self.state, State::ReadingLength);
assert_eq!(self.internal_buffer_pos, 0);
self.inner
self.inner.into_inner()
}
}
impl<I, S> Stream for LengthDelimitedFramedRead<I, S>
impl<I, S> Stream for LengthDelimited<I, S>
where
S: AsyncRead,
I: for<'r> From<&'r [u8]>,
@ -109,6 +108,7 @@ where
match self.state {
State::ReadingLength => {
match self.inner
.get_mut()
.read(&mut self.internal_buffer[self.internal_buffer_pos..])
{
Ok(0) => {
@ -166,6 +166,7 @@ where
State::ReadingData { frame_len } => {
match self.inner
.get_mut()
.read(&mut self.internal_buffer[self.internal_buffer_pos..])
{
Ok(0) => {
@ -195,12 +196,12 @@ where
}
}
impl<I, S> Sink for LengthDelimitedFramedRead<I, S>
impl<I, S> Sink for LengthDelimited<I, S>
where
S: Sink,
S: AsyncWrite
{
type SinkItem = S::SinkItem;
type SinkError = S::SinkError;
type SinkItem = <FramedWrite<S, UviBytes> as Sink>::SinkItem;
type SinkError = <FramedWrite<S, UviBytes> as Sink>::SinkError;
#[inline]
fn start_send(&mut self, item: Self::SinkItem) -> StartSend<Self::SinkItem, Self::SinkError> {
@ -236,14 +237,14 @@ fn decode_length_prefix(buf: &[u8]) -> u16 {
#[cfg(test)]
mod tests {
use futures::{Future, Stream};
use length_delimited::LengthDelimitedFramedRead;
use length_delimited::LengthDelimited;
use std::io::Cursor;
use std::io::ErrorKind;
#[test]
fn basic_read() {
let data = vec![6, 9, 8, 7, 6, 5, 4];
let framed = LengthDelimitedFramedRead::<Vec<u8>, _>::new(Cursor::new(data));
let framed = LengthDelimited::<Vec<u8>, _>::new(Cursor::new(data));
let recved = framed.collect().wait().unwrap();
assert_eq!(recved, vec![vec![9, 8, 7, 6, 5, 4]]);
@ -252,7 +253,7 @@ mod tests {
#[test]
fn basic_read_two() {
let data = vec![6, 9, 8, 7, 6, 5, 4, 3, 9, 8, 7];
let framed = LengthDelimitedFramedRead::<Vec<u8>, _>::new(Cursor::new(data));
let framed = LengthDelimited::<Vec<u8>, _>::new(Cursor::new(data));
let recved = framed.collect().wait().unwrap();
assert_eq!(recved, vec![vec![9, 8, 7, 6, 5, 4], vec![9, 8, 7]]);
@ -265,7 +266,7 @@ mod tests {
let frame = (0..len).map(|n| (n & 0xff) as u8).collect::<Vec<_>>();
let mut data = vec![(len & 0x7f) as u8 | 0x80, (len >> 7) as u8];
data.extend(frame.clone().into_iter());
let framed = LengthDelimitedFramedRead::<Vec<u8>, _>::new(Cursor::new(data));
let framed = LengthDelimited::<Vec<u8>, _>::new(Cursor::new(data));
let recved = framed
.into_future()
@ -280,7 +281,7 @@ mod tests {
fn packet_len_too_long() {
let mut data = vec![0x81, 0x81, 0x1];
data.extend((0..16513).map(|_| 0));
let framed = LengthDelimitedFramedRead::<Vec<u8>, _>::new(Cursor::new(data));
let framed = LengthDelimited::<Vec<u8>, _>::new(Cursor::new(data));
let recved = framed
.into_future()
@ -296,7 +297,7 @@ mod tests {
#[test]
fn empty_frames() {
let data = vec![0, 0, 6, 9, 8, 7, 6, 5, 4, 0, 3, 9, 8, 7];
let framed = LengthDelimitedFramedRead::<Vec<u8>, _>::new(Cursor::new(data));
let framed = LengthDelimited::<Vec<u8>, _>::new(Cursor::new(data));
let recved = framed.collect().wait().unwrap();
assert_eq!(
@ -314,7 +315,7 @@ mod tests {
#[test]
fn unexpected_eof_in_len() {
let data = vec![0x89];
let framed = LengthDelimitedFramedRead::<Vec<u8>, _>::new(Cursor::new(data));
let framed = LengthDelimited::<Vec<u8>, _>::new(Cursor::new(data));
let recved = framed.collect().wait();
match recved {
@ -326,7 +327,7 @@ mod tests {
#[test]
fn unexpected_eof_in_data() {
let data = vec![5];
let framed = LengthDelimitedFramedRead::<Vec<u8>, _>::new(Cursor::new(data));
let framed = LengthDelimited::<Vec<u8>, _>::new(Cursor::new(data));
let recved = framed.collect().wait();
match recved {
@ -338,7 +339,7 @@ mod tests {
#[test]
fn unexpected_eof_in_data2() {
let data = vec![5, 9, 8, 7];
let framed = LengthDelimitedFramedRead::<Vec<u8>, _>::new(Cursor::new(data));
let framed = LengthDelimited::<Vec<u8>, _>::new(Cursor::new(data));
let recved = framed.collect().wait();
match recved {
@ -347,3 +348,4 @@ mod tests {
}
}
}

View File

@ -118,6 +118,7 @@ extern crate futures;
#[macro_use]
extern crate log;
extern crate smallvec;
extern crate tokio_codec;
extern crate tokio_io;
extern crate unsigned_varint;

View File

@ -20,15 +20,13 @@
//! Contains the `Dialer` wrapper, which allows raw communications with a listener.
use bytes::{Bytes, BytesMut};
use bytes::Bytes;
use futures::{prelude::*, sink, Async, AsyncSink, StartSend};
use length_delimited::LengthDelimitedFramedRead;
use length_delimited::LengthDelimited;
use protocol::DialerToListenerMessage;
use protocol::ListenerToDialerMessage;
use protocol::MultistreamSelectError;
use protocol::MULTISTREAM_PROTOCOL_WITH_LF;
use tokio_io::codec::length_delimited::Builder as LengthDelimitedBuilder;
use tokio_io::codec::length_delimited::FramedWrite as LengthDelimitedFramedWrite;
use tokio_io::{AsyncRead, AsyncWrite};
use unsigned_varint::decode;
@ -36,7 +34,7 @@ use unsigned_varint::decode;
/// Wraps around a `AsyncRead+AsyncWrite`. Assumes that we're on the dialer's side. Produces and
/// accepts messages.
pub struct Dialer<R> {
inner: LengthDelimitedFramedRead<Bytes, LengthDelimitedFramedWrite<R, BytesMut>>,
inner: LengthDelimited<Bytes, R>,
handshake_finished: bool,
}
@ -47,17 +45,16 @@ where
/// Takes ownership of a socket and starts the handshake. If the handshake succeeds, the
/// future returns a `Dialer`.
pub fn new(inner: R) -> DialerFuture<R> {
let write = LengthDelimitedBuilder::new().length_field_length(1).new_write(inner);
let sender = LengthDelimitedFramedRead::new(write);
let sender = LengthDelimited::new(inner);
DialerFuture {
inner: sender.send(BytesMut::from(MULTISTREAM_PROTOCOL_WITH_LF))
inner: sender.send(Bytes::from(MULTISTREAM_PROTOCOL_WITH_LF))
}
}
/// Grants back the socket. Typically used after a `ProtocolAck` has been received.
#[inline]
pub fn into_inner(self) -> R {
self.inner.into_inner().into_inner()
self.inner.into_inner()
}
}
@ -74,14 +71,13 @@ where
if !name.starts_with(b"/") {
return Err(MultistreamSelectError::WrongProtocolName);
}
let mut protocol = BytesMut::from(name);
let mut protocol = Bytes::from(name);
protocol.extend_from_slice(&[b'\n']);
match self.inner.start_send(protocol) {
Ok(AsyncSink::Ready) => Ok(AsyncSink::Ready),
Ok(AsyncSink::NotReady(mut protocol)) => {
let protocol_len = protocol.len();
protocol.truncate(protocol_len - 1);
let protocol = protocol.freeze();
Ok(AsyncSink::NotReady(
DialerToListenerMessage::ProtocolRequest { name: protocol },
))
@ -91,7 +87,7 @@ where
}
DialerToListenerMessage::ProtocolsListRequest => {
match self.inner.start_send(BytesMut::from(&b"ls\n"[..])) {
match self.inner.start_send(Bytes::from(&b"ls\n"[..])) {
Ok(AsyncSink::Ready) => Ok(AsyncSink::Ready),
Ok(AsyncSink::NotReady(_)) => Ok(AsyncSink::NotReady(
DialerToListenerMessage::ProtocolsListRequest,
@ -171,7 +167,7 @@ where
/// Future, returned by `Dialer::new`, which send the handshake and returns the actual `Dialer`.
pub struct DialerFuture<T: AsyncWrite> {
inner: sink::Send<LengthDelimitedFramedRead<Bytes, LengthDelimitedFramedWrite<T, BytesMut>>>
inner: sink::Send<LengthDelimited<Bytes, T>>
}
impl<T: AsyncWrite> Future for DialerFuture<T> {

View File

@ -20,16 +20,14 @@
//! Contains the `Listener` wrapper, which allows raw communications with a dialer.
use bytes::{Bytes, BytesMut};
use bytes::Bytes;
use futures::{Async, AsyncSink, prelude::*, sink, stream::StreamFuture};
use length_delimited::LengthDelimitedFramedRead;
use length_delimited::LengthDelimited;
use protocol::DialerToListenerMessage;
use protocol::ListenerToDialerMessage;
use protocol::MultistreamSelectError;
use protocol::MULTISTREAM_PROTOCOL_WITH_LF;
use std::mem;
use tokio_io::codec::length_delimited::Builder as LengthDelimitedBuilder;
use tokio_io::codec::length_delimited::FramedWrite as LengthDelimitedFramedWrite;
use tokio_io::{AsyncRead, AsyncWrite};
use unsigned_varint::encode;
@ -37,7 +35,7 @@ use unsigned_varint::encode;
/// Wraps around a `AsyncRead+AsyncWrite`. Assumes that we're on the listener's side. Produces and
/// accepts messages.
pub struct Listener<R> {
inner: LengthDelimitedFramedRead<Bytes, LengthDelimitedFramedWrite<R, BytesMut>>,
inner: LengthDelimited<Bytes, R>
}
impl<R> Listener<R>
@ -47,10 +45,7 @@ where
/// Takes ownership of a socket and starts the handshake. If the handshake succeeds, the
/// future returns a `Listener`.
pub fn new(inner: R) -> ListenerFuture<R> {
let write = LengthDelimitedBuilder::new()
.length_field_length(1)
.new_write(inner);
let inner = LengthDelimitedFramedRead::<Bytes, _>::new(write);
let inner = LengthDelimited::new(inner);
ListenerFuture {
inner: ListenerFutureState::Await { inner: inner.into_future() }
}
@ -60,7 +55,7 @@ where
/// `ProtocolAck` has been sent back.
#[inline]
pub fn into_inner(self) -> R {
self.inner.into_inner().into_inner()
self.inner.into_inner()
}
}
@ -79,14 +74,13 @@ where
debug!("invalid protocol name {:?}", name);
return Err(MultistreamSelectError::WrongProtocolName);
}
let mut protocol = BytesMut::from(name);
let mut protocol = Bytes::from(name);
protocol.extend_from_slice(&[b'\n']);
match self.inner.start_send(protocol) {
Ok(AsyncSink::Ready) => Ok(AsyncSink::Ready),
Ok(AsyncSink::NotReady(mut protocol)) => {
let protocol_len = protocol.len();
protocol.truncate(protocol_len - 1);
let protocol = protocol.freeze();
Ok(AsyncSink::NotReady(ListenerToDialerMessage::ProtocolAck {
name: protocol,
}))
@ -96,7 +90,7 @@ where
}
ListenerToDialerMessage::NotAvailable => {
match self.inner.start_send(BytesMut::from(&b"na\n"[..])) {
match self.inner.start_send(Bytes::from(&b"na\n"[..])) {
Ok(AsyncSink::Ready) => Ok(AsyncSink::Ready),
Ok(AsyncSink::NotReady(_)) => {
Ok(AsyncSink::NotReady(ListenerToDialerMessage::NotAvailable))
@ -116,7 +110,7 @@ where
out_msg.extend(iter::once(b'\n'));
}
match self.inner.start_send(BytesMut::from(out_msg)) {
match self.inner.start_send(Bytes::from(out_msg)) {
Ok(AsyncSink::Ready) => Ok(AsyncSink::Ready),
Ok(AsyncSink::NotReady(_)) => {
let m = ListenerToDialerMessage::ProtocolsListResponse { list };
@ -179,10 +173,10 @@ pub struct ListenerFuture<T: AsyncRead + AsyncWrite> {
enum ListenerFutureState<T: AsyncRead + AsyncWrite> {
Await {
inner: StreamFuture<LengthDelimitedFramedRead<Bytes, LengthDelimitedFramedWrite<T, BytesMut>>>
inner: StreamFuture<LengthDelimited<Bytes, T>>
},
Reply {
sender: sink::Send<LengthDelimitedFramedRead<Bytes, LengthDelimitedFramedWrite<T, BytesMut>>>
sender: sink::Send<LengthDelimited<Bytes, T>>
},
Undefined
}
@ -209,7 +203,7 @@ impl<T: AsyncRead + AsyncWrite> Future for ListenerFuture<T> {
return Err(MultistreamSelectError::FailedHandshake)
}
trace!("sending back /multistream/<version> to finish the handshake");
let sender = socket.send(BytesMut::from(MULTISTREAM_PROTOCOL_WITH_LF));
let sender = socket.send(Bytes::from(MULTISTREAM_PROTOCOL_WITH_LF));
self.inner = ListenerFutureState::Reply { sender }
}
ListenerFutureState::Reply { mut sender } => {