Add multiplex

Vurich
2017-11-22 18:01:28 +01:00
parent 097666b09e
commit a4014bb08e
13 changed files with 1895 additions and 545 deletions


@ -4,11 +4,7 @@ members = [
"libp2p-ping", "libp2p-ping",
"libp2p-secio", "libp2p-secio",
"libp2p-swarm", "libp2p-swarm",
"libp2p-transport",
"libp2p-host",
"libp2p-tcp-transport", "libp2p-tcp-transport",
"libp2p-stream-muxer",
"multihash",
"multistream-select", "multistream-select",
"datastore", "datastore",
"rw-stream-sink", "rw-stream-sink",


@ -9,62 +9,109 @@
 extern crate smallvec;
 
-use std::ops::{Deref, DerefMut, Drop};
+use std::ops::Drop;
 use std::mem::ManuallyDrop;
 use smallvec::Array;
+use owned_slice::OwnedSlice;
 
-/// A slice that owns its elements, but not their storage. This is useful for things like
-/// `Vec::retain` and `CircularBuffer::pop_slice`, since these functions can return a slice but the
-/// elements of these slices would be leaked after the slice goes out of scope. `OwnedSlice` simply
-/// manually drops all its elements when it goes out of scope.
-#[derive(Debug, Eq, PartialEq)]
-pub struct OwnedSlice<'a, T: 'a>(&'a mut [T]);
-
-impl<'a, T: 'a> OwnedSlice<'a, T> {
-    /// Construct an owned slice from a mutable slice pointer.
-    ///
-    /// # Unsafety
-    /// You must ensure that the memory pointed to by `inner` will not be accessible after the
-    /// lifetime of the `OwnedSlice`.
-    pub unsafe fn new(inner: &'a mut [T]) -> Self {
-        OwnedSlice(inner)
-    }
-}
-
-impl<'a, T> AsRef<[T]> for OwnedSlice<'a, T> {
-    fn as_ref(&self) -> &[T] {
-        self.0
-    }
-}
-
-impl<'a, T> AsMut<[T]> for OwnedSlice<'a, T> {
-    fn as_mut(&mut self) -> &mut [T] {
-        self.0
-    }
-}
-
-impl<'a, T> Deref for OwnedSlice<'a, T> {
-    type Target = [T];
-
-    fn deref(&self) -> &Self::Target {
-        self.0
-    }
-}
-
-impl<'a, T> DerefMut for OwnedSlice<'a, T> {
-    fn deref_mut(&mut self) -> &mut Self::Target {
-        self.0
-    }
-}
-
-impl<'a, T> Drop for OwnedSlice<'a, T> {
-    fn drop(&mut self) {
-        use std::ptr;
-
-        for element in self.iter_mut() {
-            unsafe {
-                ptr::drop_in_place(element);
-            }
-        }
-    }
-}
+pub mod owned_slice {
+    use std::ops::{Deref, DerefMut, Drop};
+    use std::mem::ManuallyDrop;
+
+    /// A slice that owns its elements, but not their storage. This is useful for things like
+    /// `Vec::retain` and `CircularBuffer::pop_slice`, since these functions can return a slice but
+    /// the elements of these slices would be leaked after the slice goes out of scope. `OwnedSlice`
+    /// simply manually drops all its elements when it goes out of scope.
+    #[derive(Debug, Eq, PartialEq)]
+    pub struct OwnedSlice<'a, T: 'a>(&'a mut [T]);
+
+    /// Owning iterator for `OwnedSlice`.
+    pub struct IntoIter<'a, T: 'a> {
+        slice: ManuallyDrop<OwnedSlice<'a, T>>,
+        index: usize,
+    }
+
+    impl<'a, T> Iterator for IntoIter<'a, T> {
+        type Item = T;
+
+        fn next(&mut self) -> Option<Self::Item> {
+            use std::ptr;
+
+            let index = self.index;
+
+            if index >= self.slice.len() {
+                return None;
+            }
+
+            self.index += 1;
+
+            unsafe { Some(ptr::read(&self.slice[index])) }
+        }
+    }
+
+    impl<'a, T: 'a> IntoIterator for OwnedSlice<'a, T> {
+        type Item = T;
+        type IntoIter = IntoIter<'a, T>;
+
+        fn into_iter(self) -> Self::IntoIter {
+            IntoIter {
+                slice: ManuallyDrop::new(self),
+                index: 0,
+            }
+        }
+    }
+
+    impl<'a, T: 'a> OwnedSlice<'a, T> {
+        /// Construct an owned slice from a mutable slice pointer.
+        ///
+        /// # Unsafety
+        /// You must ensure that the memory pointed to by `inner` will not be accessible after the
+        /// lifetime of the `OwnedSlice`.
+        pub unsafe fn new(inner: &'a mut [T]) -> Self {
+            OwnedSlice(inner)
+        }
+    }
+
+    impl<'a, T> AsRef<[T]> for OwnedSlice<'a, T> {
+        fn as_ref(&self) -> &[T] {
+            self.0
+        }
+    }
+
+    impl<'a, T> AsMut<[T]> for OwnedSlice<'a, T> {
+        fn as_mut(&mut self) -> &mut [T] {
+            self.0
+        }
+    }
+
+    impl<'a, T> Deref for OwnedSlice<'a, T> {
+        type Target = [T];
+
+        fn deref(&self) -> &Self::Target {
+            self.0
+        }
+    }
+
+    impl<'a, T> DerefMut for OwnedSlice<'a, T> {
+        fn deref_mut(&mut self) -> &mut Self::Target {
+            self.0
+        }
+    }
+
+    impl<'a, T> Drop for OwnedSlice<'a, T> {
+        fn drop(&mut self) {
+            use std::ptr;
+
+            for element in self.iter_mut() {
+                unsafe {
+                    ptr::drop_in_place(element);
+                }
+            }
+        }
+    }
+}
@ -82,6 +129,12 @@ pub struct CircularBuffer<B: Array> {
     len: usize,
 }
 
+impl<B: Array> Default for CircularBuffer<B> {
+    fn default() -> Self {
+        Self::new()
+    }
+}
impl<B: Array> PartialEq for CircularBuffer<B> impl<B: Array> PartialEq for CircularBuffer<B>
where where
B::Item: PartialEq, B::Item: PartialEq,
@ -97,7 +150,7 @@ where
             }
         }
 
-        return true;
+        true
     }
 }
@ -148,7 +201,9 @@ impl<B: Array> CircularBuffer<B> {
     /// when the slice goes out of scope), if you're using non-`Drop` types you can use
     /// `pop_slice_leaky`.
     pub fn pop_slice(&mut self) -> Option<OwnedSlice<B::Item>> {
-        self.pop_slice_leaky().map(OwnedSlice)
+        self.pop_slice_leaky().map(
+            |x| unsafe { OwnedSlice::new(x) },
+        )
     }
 
     /// Pop a slice containing the maximum possible contiguous number of elements. Since this buffer
@ -357,18 +412,35 @@ impl<B: Array> CircularBuffer<B> {
         }
     }
 
-    /// Get a borrow to an element at an index unsafely (causes undefined behaviour if the index is
-    /// out of bounds).
+    /// Get a borrow to an element at an index unsafely (behaviour is undefined if the index is out
+    /// of bounds).
     pub unsafe fn get_unchecked(&self, index: usize) -> &B::Item {
-        use std::mem;
-
-        mem::transmute(self.buffer.ptr().offset(
+        &*self.buffer.ptr().offset(
             ((index + self.start) % B::size()) as isize,
-        ))
+        )
+    }
+
+    /// Get a mutable borrow to an element at an index safely (if the index is out of bounds, return
+    /// `None`).
+    pub fn get_mut(&mut self, index: usize) -> Option<&mut B::Item> {
+        if index < self.len {
+            unsafe { Some(self.get_unchecked_mut(index)) }
+        } else {
+            None
+        }
+    }
+
+    /// Get a mutable borrow to an element at an index unsafely (behaviour is undefined if the index
+    /// is out of bounds).
+    pub unsafe fn get_unchecked_mut(&mut self, index: usize) -> &mut B::Item {
+        &mut *self.buffer.ptr_mut().offset(
+            ((index + self.start) % B::size()) as isize,
+        )
     }
 
     // This is not unsafe because it can only leak data, not cause uninit to be read.
-    fn advance(&mut self, by: usize) {
+    pub fn advance(&mut self, by: usize) {
         assert!(by <= self.len);
 
         self.start = (self.start + by) % B::size();
@ -376,6 +448,39 @@ impl<B: Array> CircularBuffer<B> {
     }
 }
 
+impl<B: Array> std::ops::Index<usize> for CircularBuffer<B> {
+    type Output = B::Item;
+
+    fn index(&self, index: usize) -> &Self::Output {
+        if let Some(out) = self.get(index) {
+            out
+        } else {
+            panic!(
+                "index out of bounds: the len is {} but the index is {}",
+                self.len,
+                index
+            );
+        }
+    }
+}
+
+impl<B: Array> std::ops::IndexMut<usize> for CircularBuffer<B> {
+    fn index_mut(&mut self, index: usize) -> &mut Self::Output {
+        // We need to do this because borrowck isn't smart enough to understand enum variants
+        let len = self.len;
+        if let Some(out) = self.get_mut(index) {
+            return out;
+        } else {
+            panic!(
+                "index out of bounds: the len is {} but the index is {}",
+                len,
+                index
+            );
+        }
+    }
+}
+
 impl<B: Array> Drop for CircularBuffer<B> {
     fn drop(&mut self) {
         while self.pop_slice().is_some() {}


@ -1,8 +0,0 @@
[package]
name = "libp2p-stream-muxer"
version = "0.1.0"
authors = ["Vurich <jackefransham@gmail.com>"]
[dependencies]
futures = "0.1"
tokio-io = "0.1"


@ -1,22 +0,0 @@
extern crate tokio_io;
extern crate futures;
use futures::stream::Stream;
use tokio_io::{AsyncRead, AsyncWrite};
pub trait StreamMuxer {
type Substream: AsyncRead + AsyncWrite;
type InboundSubstreams: Stream<Item = Self::Substream>;
type OutboundSubstreams: Stream<Item = Self::Substream>;
fn inbound(&mut self) -> Self::InboundSubstreams;
fn outbound(&mut self) -> Self::OutboundSubstreams;
}
#[cfg(test)]
mod tests {
#[test]
fn it_works() {
assert_eq!(2 + 2, 4);
}
}


@ -10,6 +10,8 @@ num-bigint = "0.1.40"
tokio-io = "0.1" tokio-io = "0.1"
futures = "0.1" futures = "0.1"
parking_lot = "0.4.8" parking_lot = "0.4.8"
libp2p-stream-muxer = { path = "../libp2p-stream-muxer" } arrayvec = "0.4.6"
rand = "0.3.17"
libp2p-swarm = { path = "../libp2p-swarm" }
varint = { path = "../varint-rs" } varint = { path = "../varint-rs" }
circular-buffer = { path = "../circular-buffer" } error-chain = "0.11.0"

multiplex-rs/README.md Normal file

@ -0,0 +1,3 @@
# Multiplex
A Rust implementation of [multiplex](https://github.com/maxogden/multiplex), the `mplex` stream-multiplexing protocol (advertised here as `/mplex/6.7.0`) that runs many independent substreams over a single connection.
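
The API added in this commit is future-based: `Multiplex::dial` and `Multiplex::listen` wrap an `AsyncRead + AsyncWrite` connection, and `outbound()`/`inbound()` resolve to substreams that implement `Read` and `Write`. A rough usage sketch modelled on the unit tests in this commit (the crate name `multiplex` and the way the two cursors are wired together are assumptions; a real transport would be used in practice):

```rust
extern crate futures;
extern crate multiplex; // assumed crate name

use futures::Future;
use multiplex::Multiplex;
use std::io::{self, Read, Write};

fn main() {
    // Dialing side: wrap the connection and open an outbound substream.
    let conn = io::Cursor::new(Vec::new());
    let mplex = Multiplex::dial(conn);
    let mut outbound = mplex.clone().outbound().wait().unwrap();
    outbound.write(b"Hello, world!").unwrap();

    // Listening side: wrap the other end of the same connection and accept the
    // substream the dialer opened. The empty buffer is a placeholder for the
    // bytes actually produced by the dialing side.
    let bytes_from_dialer: Vec<u8> = Vec::new();
    let mplex = Multiplex::listen(io::Cursor::new(bytes_from_dialer));
    let mut inbound = mplex.inbound().wait().unwrap();

    let mut buf = [0; 13];
    inbound.read(&mut buf).unwrap();
}
```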

multiplex-rs/src/header.rs Normal file

@ -0,0 +1,145 @@
// Copyright 2017 Parity Technologies (UK) Ltd.
//
// Permission is hereby granted, free of charge, to any person obtaining a
// copy of this software and associated documentation files (the "Software"),
// to deal in the Software without restriction, including without limitation
// the rights to use, copy, modify, merge, publish, distribute, sublicense,
// and/or sell copies of the Software, and to permit persons to whom the
// Software is furnished to do so, subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in
// all copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
// OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
// FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
// DEALINGS IN THE SOFTWARE.
const FLAG_BITS: usize = 3;
const FLAG_MASK: usize = (1usize << FLAG_BITS) - 1;
pub mod errors {
error_chain! {
errors {
ParseError
}
}
}
#[derive(Copy, Clone, PartialEq, Eq, Debug)]
pub enum MultiplexEnd {
Initiator,
Receiver,
}
#[derive(Copy, Clone, PartialEq, Eq, Debug)]
pub struct MultiplexHeader {
pub packet_type: PacketType,
pub substream_id: u32,
}
#[derive(Copy, Clone, PartialEq, Eq, Debug)]
pub enum PacketType {
Open,
Close(MultiplexEnd),
Reset(MultiplexEnd),
Message(MultiplexEnd),
}
impl MultiplexHeader {
pub fn open(id: u32) -> Self {
MultiplexHeader {
substream_id: id,
packet_type: PacketType::Open,
}
}
pub fn close(id: u32, end: MultiplexEnd) -> Self {
MultiplexHeader {
substream_id: id,
packet_type: PacketType::Close(end),
}
}
pub fn reset(id: u32, end: MultiplexEnd) -> Self {
MultiplexHeader {
substream_id: id,
packet_type: PacketType::Reset(end),
}
}
pub fn message(id: u32, end: MultiplexEnd) -> Self {
MultiplexHeader {
substream_id: id,
packet_type: PacketType::Message(end),
}
}
// TODO: Use `u128` or another large integer type instead of bigint since we never use more than
// `pointer width + FLAG_BITS` bits and unconditionally allocating 1-3 `u32`s for that is
// ridiculous (especially since even for small numbers we have to allocate 1 `u32`).
// If this is the future and `BigUint` is better-optimised (maybe by using `Bytes`) then
// forget it.
pub fn parse(header: u64) -> Result<MultiplexHeader, errors::Error> {
use num_traits::cast::ToPrimitive;
let flags = header & FLAG_MASK as u64;
let substream_id = (header >> FLAG_BITS)
.to_u32()
.ok_or(errors::ErrorKind::ParseError)?;
// Yes, this is really how it works. No, I don't know why.
let packet_type = match flags {
0 => PacketType::Open,
1 => PacketType::Message(MultiplexEnd::Receiver),
2 => PacketType::Message(MultiplexEnd::Initiator),
3 => PacketType::Close(MultiplexEnd::Receiver),
4 => PacketType::Close(MultiplexEnd::Initiator),
5 => PacketType::Reset(MultiplexEnd::Receiver),
6 => PacketType::Reset(MultiplexEnd::Initiator),
_ => {
use std::io;
return Err(errors::Error::with_chain(
io::Error::new(
io::ErrorKind::Other,
format!("Unexpected packet type: {}", flags),
),
errors::ErrorKind::ParseError,
));
}
};
Ok(MultiplexHeader {
substream_id,
packet_type,
})
}
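    // Worked example of the layout used by `parse` and `to_u64`: the low
    // `FLAG_BITS` (3) bits carry the packet type and the remaining bits the
    // substream ID, so `MultiplexHeader::open(1).to_u64()` is `(1 << 3) | 0 == 8`
    // and `MultiplexHeader::message(1, MultiplexEnd::Initiator).to_u64()` is
    // `(1 << 3) | 2 == 10`; `parse(10)` recovers substream 1 and
    // `PacketType::Message(MultiplexEnd::Initiator)` again.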
pub fn to_u64(&self) -> u64 {
let packet_type_id = match self.packet_type {
PacketType::Open => 0,
PacketType::Message(MultiplexEnd::Receiver) => 1,
PacketType::Message(MultiplexEnd::Initiator) => 2,
PacketType::Close(MultiplexEnd::Receiver) => 3,
PacketType::Close(MultiplexEnd::Initiator) => 4,
PacketType::Reset(MultiplexEnd::Receiver) => 5,
PacketType::Reset(MultiplexEnd::Initiator) => 6,
};
let substream_id = (self.substream_id as u64) << FLAG_BITS;
substream_id | packet_type_id
}
}


@ -1,22 +1,56 @@
// Copyright 2017 Parity Technologies (UK) Ltd.
//
// Permission is hereby granted, free of charge, to any person obtaining a
// copy of this software and associated documentation files (the "Software"),
// to deal in the Software without restriction, including without limitation
// the rights to use, copy, modify, merge, publish, distribute, sublicense,
// and/or sell copies of the Software, and to permit persons to whom the
// Software is furnished to do so, subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in
// all copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
// OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
// FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
// DEALINGS IN THE SOFTWARE.
+extern crate arrayvec;
 extern crate bytes;
+#[macro_use]
+extern crate error_chain;
 extern crate futures;
-extern crate libp2p_stream_muxer;
-extern crate tokio_io;
-extern crate varint;
+extern crate libp2p_swarm as swarm;
 extern crate num_bigint;
 extern crate num_traits;
 extern crate parking_lot;
-extern crate circular_buffer;
+extern crate rand;
+extern crate tokio_io;
+extern crate varint;
+
+mod read;
+mod write;
+mod shared;
+mod header;
 
 use bytes::Bytes;
-use circular_buffer::CircularBuffer;
-use futures::prelude::*;
-use libp2p_stream_muxer::StreamMuxer;
+use futures::{Async, Future, Poll};
+use futures::future::{self, FutureResult};
+use header::{MultiplexEnd, MultiplexHeader};
+use swarm::muxing::StreamMuxer;
+use swarm::ConnectionUpgrade;
 use parking_lot::Mutex;
-use std::collections::HashMap;
+use read::{read_stream, MultiplexReadState};
+use shared::{buf_from_slice, ByteBuf, MultiplexShared};
+use std::iter;
 use std::io::{self, Read, Write};
 use std::sync::Arc;
+use std::sync::atomic::{self, AtomicUsize};
 use tokio_io::{AsyncRead, AsyncWrite};
+use write::write_stream;
 // So the multiplex is essentially a distributed finite state machine.
 //
@ -30,338 +64,55 @@ use tokio_io::{AsyncRead, AsyncWrite};
 // In the second state, the substream ID is known. Only this substream can progress until the packet
 // is consumed.
/// Number of bits used for the metadata on multiplex packets
enum NextMultiplexState {
NewStream(usize),
ParsingMessageBody(usize),
Ignore,
}
enum MultiplexReadState {
Header { state: varint::DecoderState },
BodyLength {
state: varint::DecoderState,
next: NextMultiplexState,
},
NewStream {
substream_id: usize,
name: bytes::BytesMut,
remaining_bytes: usize,
},
ParsingMessageBody {
substream_id: usize,
remaining_bytes: usize,
},
Ignore { remaining_bytes: usize },
}
impl Default for MultiplexReadState {
fn default() -> Self {
MultiplexReadState::Header { state: Default::default() }
}
}
struct MultiplexWriteState {
buffer: CircularBuffer<[u8; 1024]>,
}
// TODO: Add writing. We should also add some form of "pending packet" so that we can always open at
// least one new substream. If this is stored on the substream itself then we can open
// infinite new substreams.
//
// When we've implemented writing, we should send the close message on `Substream` drop. This
// should probably be implemented with some kind of "pending close message" queue. The
// priority should go:
// 1. Open new stream messages
// 2. Regular messages
// 3. Close messages
// Since if we receive a message to a closed stream we just drop it anyway.
struct MultiplexShared<T> {
// We use `Option` in order to take ownership of heap allocations within `DecoderState` and
// `BytesMut`. If this is ever observably `None` then something has panicked or the underlying
// stream returned an error.
read_state: Option<MultiplexReadState>,
stream: T,
// true if the stream is open, false otherwise
open_streams: HashMap<usize, bool>,
// TODO: Should we use a version of this with a fixed size that doesn't allocate and return
// `WouldBlock` if it's full?
to_open: HashMap<usize, Bytes>,
}
 pub struct Substream<T> {
-    id: usize,
+    id: u32,
+    end: MultiplexEnd,
     name: Option<Bytes>,
     state: Arc<Mutex<MultiplexShared<T>>>,
+    buffer: Option<io::Cursor<ByteBuf>>,
 }
 
 impl<T> Drop for Substream<T> {
     fn drop(&mut self) {
         let mut lock = self.state.lock();
 
-        lock.open_streams.insert(self.id, false);
+        lock.close_stream(self.id);
     }
 }
 
 impl<T> Substream<T> {
     fn new<B: Into<Option<Bytes>>>(
-        id: usize,
+        id: u32,
+        end: MultiplexEnd,
         name: B,
         state: Arc<Mutex<MultiplexShared<T>>>,
     ) -> Self {
         let name = name.into();
 
-        Substream { id, name, state }
+        Substream {
+            id,
+            end,
+            name,
+            state,
+            buffer: None,
+        }
     }
 
     pub fn name(&self) -> Option<&Bytes> {
         self.name.as_ref()
     }
-}
 
-/// This is unsafe because you must ensure that only the `AsyncRead` that was passed in is later
-/// used to write to the returned buffer.
+    pub fn id(&self) -> u32 {
+        self.id
+    }
+}
unsafe fn create_buffer_for<R: AsyncRead>(capacity: usize, inner: &R) -> bytes::BytesMut {
let mut buffer = bytes::BytesMut::with_capacity(capacity);
buffer.set_len(capacity);
inner.prepare_uninitialized_buffer(&mut buffer);
buffer
}
fn read_stream<'a, O: Into<Option<(usize, &'a mut [u8])>>, T: AsyncRead>(
lock: &mut MultiplexShared<T>,
stream_data: O,
) -> io::Result<usize> {
use num_traits::cast::ToPrimitive;
use MultiplexReadState::*;
let mut stream_data = stream_data.into();
let stream_has_been_gracefully_closed = stream_data
.as_ref()
.and_then(|&(id, _)| lock.open_streams.get(&id))
.map(|is_open| !is_open)
.unwrap_or(false);
let mut on_block: io::Result<usize> = if stream_has_been_gracefully_closed {
Ok(0)
} else {
Err(io::Error::from(io::ErrorKind::WouldBlock))
};
loop {
match lock.read_state.take().expect("Logic error or panic") {
Header { state: varint_state } => {
match varint_state.read(&mut lock.stream).map_err(|_| {
io::Error::from(io::ErrorKind::Other)
})? {
Ok(header) => {
let MultiplexHeader {
substream_id,
packet_type,
} = MultiplexHeader::parse(header).map_err(|_| {
io::Error::from(io::ErrorKind::Other)
})?;
match packet_type {
PacketType::Open => {
lock.read_state = Some(BodyLength {
state: Default::default(),
next: NextMultiplexState::NewStream(substream_id),
})
}
PacketType::Message(_) => {
lock.read_state = Some(BodyLength {
state: Default::default(),
next: NextMultiplexState::ParsingMessageBody(substream_id),
})
}
// NOTE: What's the difference between close and reset?
PacketType::Close(_) |
PacketType::Reset(_) => {
lock.read_state = Some(BodyLength {
state: Default::default(),
next: NextMultiplexState::Ignore,
});
lock.open_streams.remove(&substream_id);
}
}
}
Err(new_state) => {
lock.read_state = Some(Header { state: new_state });
return on_block;
}
}
}
BodyLength {
state: varint_state,
next,
} => {
match varint_state.read(&mut lock.stream).map_err(|_| {
io::Error::from(io::ErrorKind::Other)
})? {
Ok(length) => {
use NextMultiplexState::*;
let length = length.to_usize().ok_or(
io::Error::from(io::ErrorKind::Other),
)?;
lock.read_state = Some(match next {
Ignore => MultiplexReadState::Ignore { remaining_bytes: length },
NewStream(substream_id) => MultiplexReadState::NewStream {
// This is safe as long as we only use `lock.stream` to write to
// this field
name: unsafe { create_buffer_for(length, &lock.stream) },
remaining_bytes: length,
substream_id,
},
ParsingMessageBody(substream_id) => {
let is_open = lock.open_streams
.get(&substream_id)
.map(|is_open| *is_open)
.unwrap_or(false);
if is_open {
MultiplexReadState::ParsingMessageBody {
remaining_bytes: length,
substream_id,
}
} else {
MultiplexReadState::Ignore { remaining_bytes: length }
}
}
});
}
Err(new_state) => {
lock.read_state = Some(BodyLength {
state: new_state,
next,
});
return on_block;
}
}
}
NewStream {
substream_id,
mut name,
remaining_bytes,
} => {
if remaining_bytes == 0 {
lock.to_open.insert(substream_id, name.freeze());
lock.read_state = Some(Default::default());
} else {
let cursor_pos = name.len() - remaining_bytes;
let consumed = lock.stream.read(&mut name[cursor_pos..]);
match consumed {
Ok(consumed) => {
let new_remaining = remaining_bytes - consumed;
lock.read_state = Some(NewStream {
substream_id,
name,
remaining_bytes: new_remaining,
})
}
Err(ref err) if err.kind() == io::ErrorKind::WouldBlock => {
lock.read_state = Some(NewStream {
substream_id,
name,
remaining_bytes,
});
return on_block;
}
Err(other) => return Err(other),
}
}
}
ParsingMessageBody {
substream_id,
remaining_bytes,
} => {
if let Some((ref mut id, ref mut buf)) = stream_data {
use MultiplexReadState::*;
if substream_id == *id {
if remaining_bytes == 0 {
lock.read_state = Some(Default::default());
} else {
let read_result = {
let new_len = buf.len().min(remaining_bytes);
let slice = &mut buf[..new_len];
lock.stream.read(slice)
};
match read_result {
Ok(consumed) => {
let new_remaining = remaining_bytes - consumed;
lock.read_state = Some(ParsingMessageBody {
substream_id,
remaining_bytes: new_remaining,
});
on_block = Ok(on_block.unwrap_or(0) + consumed);
}
Err(ref err) if err.kind() == io::ErrorKind::WouldBlock => {
lock.read_state = Some(ParsingMessageBody {
substream_id,
remaining_bytes,
});
return on_block;
}
Err(other) => return Err(other),
}
}
} else {
lock.read_state = Some(ParsingMessageBody {
substream_id,
remaining_bytes,
});
// We cannot make progress here, another stream has to accept this packet
return on_block;
}
}
}
Ignore { mut remaining_bytes } => {
let mut ignore_buf: [u8; 256] = [0; 256];
loop {
if remaining_bytes == 0 {
lock.read_state = Some(Default::default());
} else {
let new_len = ignore_buf.len().min(remaining_bytes);
match lock.stream.read(&mut ignore_buf[..new_len]) {
Ok(consumed) => remaining_bytes -= consumed,
Err(ref err) if err.kind() == io::ErrorKind::WouldBlock => {
lock.read_state = Some(Ignore { remaining_bytes });
return on_block;
}
Err(other) => return Err(other),
}
}
}
}
-        }
-    }
-}
 
-// TODO: We always zero the buffer, we should delegate to the inner stream. Maybe use a `RWLock`
-// instead?
+// TODO: We always zero the buffer, we should delegate to the inner stream.
 impl<T: AsyncRead> Read for Substream<T> {
+    // TODO: Is it wasteful to have all of our substreams try to make progress? Can we use an
+    // `AtomicBool` or `AtomicUsize` to limit the substreams that try to progress?
     fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
         let mut lock = match self.state.try_lock() {
             Some(lock) => lock,
-            None => return Err(io::Error::from(io::ErrorKind::WouldBlock)),
+            None => return Err(io::ErrorKind::WouldBlock.into()),
         };
 
         read_stream(&mut lock, (self.id, buf))
@ -372,91 +123,50 @@ impl<T: AsyncRead> AsyncRead for Substream<T> {}
 impl<T: AsyncWrite> Write for Substream<T> {
     fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
-        unimplemented!()
+        let mut lock = self.state.try_lock().ok_or(io::ErrorKind::WouldBlock)?;
+
+        let mut buffer = self.buffer
+            .take()
+            .unwrap_or_else(|| io::Cursor::new(buf_from_slice(buf)));
+
+        let out = write_stream(
+            &mut *lock,
+            write::WriteRequest::substream(MultiplexHeader::message(self.id, self.end)),
+            &mut buffer,
+        );
+
+        if buffer.position() < buffer.get_ref().len() as u64 {
+            self.buffer = Some(buffer);
+        }
+
+        out
     }
 
     fn flush(&mut self) -> io::Result<()> {
-        unimplemented!()
+        self.state
+            .try_lock()
+            .ok_or(io::ErrorKind::WouldBlock)?
+            .stream
+            .flush()
     }
 }
 
 impl<T: AsyncWrite> AsyncWrite for Substream<T> {
     fn shutdown(&mut self) -> Poll<(), io::Error> {
-        unimplemented!()
+        Ok(Async::Ready(()))
     }
 }
struct ParseError; pub struct InboundFuture<T> {
end: MultiplexEnd,
enum MultiplexEnd {
Initiator,
Receiver,
}
struct MultiplexHeader {
pub packet_type: PacketType,
pub substream_id: usize,
}
enum PacketType {
Open,
Close(MultiplexEnd),
Reset(MultiplexEnd),
Message(MultiplexEnd),
}
impl MultiplexHeader {
// TODO: Use `u128` or another large integer type instead of bigint since we never use more than
// `pointer width + FLAG_BITS` bits and unconditionally allocating 1-3 `u32`s for that is
// ridiculous (especially since even for small numbers we have to allocate 1 `u32`).
// If this is the future and `BigUint` is better-optimised (maybe by using `Bytes`) then
// forget it.
fn parse(header: num_bigint::BigUint) -> Result<MultiplexHeader, ParseError> {
use num_traits::cast::ToPrimitive;
const FLAG_BITS: usize = 3;
// `&header` to make `>>` produce a new `BigUint` instead of consuming the old `BigUint`
let substream_id = ((&header) >> FLAG_BITS).to_usize().ok_or(ParseError)?;
let flag_mask = (2usize << FLAG_BITS) - 1;
let flags = header.to_usize().ok_or(ParseError)? & flag_mask;
// Yes, this is really how it works. No, I don't know why.
let packet_type = match flags {
0 => PacketType::Open,
1 => PacketType::Message(MultiplexEnd::Receiver),
2 => PacketType::Message(MultiplexEnd::Initiator),
3 => PacketType::Close(MultiplexEnd::Receiver),
4 => PacketType::Close(MultiplexEnd::Initiator),
5 => PacketType::Reset(MultiplexEnd::Receiver),
6 => PacketType::Reset(MultiplexEnd::Initiator),
_ => return Err(ParseError),
};
Ok(MultiplexHeader {
substream_id,
packet_type,
})
}
}
pub struct Multiplex<T> {
state: Arc<Mutex<MultiplexShared<T>>>, state: Arc<Mutex<MultiplexShared<T>>>,
} }
-pub struct InboundStream<T> {
-    state: Arc<Mutex<MultiplexShared<T>>>,
-}
-
-impl<T: AsyncRead> Stream for InboundStream<T> {
+impl<T: AsyncRead> Future for InboundFuture<T> {
     type Item = Substream<T>;
     type Error = io::Error;
 
-    fn poll(&mut self) -> Poll<Option<Self::Item>, Self::Error> {
+    fn poll(&mut self) -> Poll<Self::Item, Self::Error> {
         let mut lock = match self.state.try_lock() {
             Some(lock) => lock,
             None => return Ok(Async::NotReady),
@ -464,8 +174,8 @@ impl<T: AsyncRead> Stream for InboundStream<T> {
         // Attempt to make progress, but don't block if we can't
         match read_stream(&mut lock, None) {
-            Ok(_) => (),
-            Err(ref err) if err.kind() == io::ErrorKind::WouldBlock => (),
+            Ok(_) => {}
+            Err(ref err) if err.kind() == io::ErrorKind::WouldBlock => {}
             Err(err) => return Err(err),
         }
@ -479,30 +189,413 @@ impl<T: AsyncRead> Stream for InboundStream<T> {
"We just checked that this key exists and we have exclusive access to the map, QED", "We just checked that this key exists and we have exclusive access to the map, QED",
); );
Ok(Async::Ready( lock.open_stream(id);
Some(Substream::new(id, name, self.state.clone())),
)) Ok(Async::Ready(Substream::new(
id,
self.end,
name,
Arc::clone(&self.state),
)))
}
}
pub struct OutboundFuture<T> {
meta: Arc<MultiplexMetadata>,
current_id: Option<(io::Cursor<ByteBuf>, u32)>,
state: Arc<Mutex<MultiplexShared<T>>>,
}
impl<T> OutboundFuture<T> {
fn new(muxer: Multiplex<T>) -> Self {
OutboundFuture {
current_id: None,
meta: muxer.meta,
state: muxer.state,
}
}
}
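// Note on the ID scheme below: the initiator maps its nonces 0, 1, 2, ... to odd
// substream IDs (1, 3, 5, ...) and the receiver maps them to even IDs (0, 2, 4, ...),
// so the two ends can never allocate the same substream ID.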
fn nonce_to_id(id: usize, end: MultiplexEnd) -> u32 {
id as u32 * 2 + if end == MultiplexEnd::Initiator { 1 } else { 0 }
}
impl<T: AsyncWrite> Future for OutboundFuture<T> {
type Item = Substream<T>;
type Error = io::Error;
fn poll(&mut self) -> Poll<Self::Item, Self::Error> {
let mut lock = match self.state.try_lock() {
Some(lock) => lock,
None => return Ok(Async::NotReady),
};
loop {
let (mut id_str, id) = self.current_id.take().unwrap_or_else(|| {
let next = nonce_to_id(
self.meta.nonce.fetch_add(1, atomic::Ordering::Relaxed),
self.meta.end,
);
(
io::Cursor::new(buf_from_slice(format!("{}", next).as_bytes())),
next as u32,
)
});
match write_stream(
&mut *lock,
write::WriteRequest::meta(MultiplexHeader::open(id)),
&mut id_str,
) {
Ok(_) => {
debug_assert!(id_str.position() <= id_str.get_ref().len() as u64);
if id_str.position() == id_str.get_ref().len() as u64 {
if lock.open_stream(id) {
return Ok(Async::Ready(Substream::new(
id,
self.meta.end,
Bytes::from(&id_str.get_ref()[..]),
Arc::clone(&self.state),
)));
}
} else {
self.current_id = Some((id_str, id));
return Ok(Async::NotReady);
}
}
Err(ref err) if err.kind() == io::ErrorKind::WouldBlock => {
self.current_id = Some((id_str, id));
return Ok(Async::NotReady);
}
Err(other) => return Err(other),
}
}
}
}
pub struct MultiplexMetadata {
nonce: AtomicUsize,
end: MultiplexEnd,
}
pub struct Multiplex<T> {
meta: Arc<MultiplexMetadata>,
state: Arc<Mutex<MultiplexShared<T>>>,
}
impl<T> Clone for Multiplex<T> {
fn clone(&self) -> Self {
Multiplex {
meta: self.meta.clone(),
state: self.state.clone(),
}
}
}
impl<T> Multiplex<T> {
pub fn new(stream: T, end: MultiplexEnd) -> Self {
Multiplex {
meta: Arc::new(MultiplexMetadata {
nonce: AtomicUsize::new(0),
end,
}),
state: Arc::new(Mutex::new(MultiplexShared::new(stream))),
}
}
pub fn dial(stream: T) -> Self {
Self::new(stream, MultiplexEnd::Initiator)
}
pub fn listen(stream: T) -> Self {
Self::new(stream, MultiplexEnd::Receiver)
} }
} }
 impl<T: AsyncRead + AsyncWrite> StreamMuxer for Multiplex<T> {
     type Substream = Substream<T>;
-    type OutboundSubstreams = Box<Stream<Item = Self::Substream, Error = io::Error>>;
-    type InboundSubstreams = InboundStream<T>;
+    type OutboundSubstream = OutboundFuture<T>;
+    type InboundSubstream = InboundFuture<T>;
 
-    fn inbound(&mut self) -> Self::InboundSubstreams {
-        InboundStream { state: self.state.clone() }
+    fn inbound(self) -> Self::InboundSubstream {
+        InboundFuture {
+            state: Arc::clone(&self.state),
+            end: self.meta.end,
+        }
     }
 
-    fn outbound(&mut self) -> Self::OutboundSubstreams {
-        unimplemented!()
+    fn outbound(self) -> Self::OutboundSubstream {
+        OutboundFuture::new(self)
     }
 }
pub struct MultiplexConfig;
impl<C> ConnectionUpgrade<C> for MultiplexConfig
where
C: AsyncRead + AsyncWrite,
{
type Output = Multiplex<C>;
type Future = FutureResult<Multiplex<C>, io::Error>;
type UpgradeIdentifier = ();
type NamesIter = iter::Once<(Bytes, ())>;
#[inline]
fn upgrade(self, i: C, _: ()) -> Self::Future {
future::ok(Multiplex::dial(i))
}
#[inline]
fn protocol_names(&self) -> Self::NamesIter {
iter::once((Bytes::from("/mplex/6.7.0"), ()))
} }
} }
 #[cfg(test)]
 mod tests {
+    use super::*;
+    use std::io;
+
     #[test]
-    fn it_works() {
-        assert_eq!(2 + 2, 4);
+    fn can_use_one_stream() {
+        let message = b"Hello, world!";
+
+        let stream = io::Cursor::new(Vec::new());
let mplex = Multiplex::dial(stream);
let mut substream = mplex.clone().outbound().wait().unwrap();
assert!(substream.write(message).is_ok());
let id = substream.id();
assert_eq!(
substream
.name()
.and_then(|bytes| { String::from_utf8(bytes.to_vec()).ok() }),
Some(id.to_string())
);
let stream = io::Cursor::new(mplex.state.lock().stream.get_ref().clone());
let mplex = Multiplex::listen(stream);
let mut substream = mplex.inbound().wait().unwrap();
assert_eq!(id, substream.id());
assert_eq!(
substream
.name()
.and_then(|bytes| { String::from_utf8(bytes.to_vec()).ok() }),
Some(id.to_string())
);
let mut buf = vec![0; message.len()];
assert!(substream.read(&mut buf).is_ok());
assert_eq!(&buf, message);
}
#[test]
fn can_use_many_streams() {
let stream = io::Cursor::new(Vec::new());
let mplex = Multiplex::dial(stream);
let mut outbound: Vec<Substream<_>> = vec![
mplex.clone().outbound().wait().unwrap(),
mplex.clone().outbound().wait().unwrap(),
mplex.clone().outbound().wait().unwrap(),
mplex.clone().outbound().wait().unwrap(),
mplex.clone().outbound().wait().unwrap(),
];
outbound.sort_by_key(|a| a.id());
for (i, substream) in outbound.iter_mut().enumerate() {
assert!(substream.write(i.to_string().as_bytes()).is_ok());
}
let stream = io::Cursor::new(mplex.state.lock().stream.get_ref().clone());
let mplex = Multiplex::listen(stream);
let mut inbound: Vec<Substream<_>> = vec![
mplex.clone().inbound().wait().unwrap(),
mplex.clone().inbound().wait().unwrap(),
mplex.clone().inbound().wait().unwrap(),
mplex.clone().inbound().wait().unwrap(),
mplex.clone().inbound().wait().unwrap(),
];
inbound.sort_by_key(|a| a.id());
for (substream, outbound) in inbound.iter_mut().zip(outbound.iter()) {
let id = outbound.id();
assert_eq!(id, substream.id());
assert_eq!(
substream
.name()
.and_then(|bytes| { String::from_utf8(bytes.to_vec()).ok() }),
Some(id.to_string())
);
let mut buf = [0; 3];
assert_eq!(substream.read(&mut buf).unwrap(), 1);
}
}
#[test]
fn packets_to_unopened_streams_are_dropped() {
use std::iter;
let message = b"Hello, world!";
// We use a large dummy length to exercise ignoring data longer than `ignore_buffer.len()`
let dummy_length = 1000;
let input = iter::empty()
// Open a stream
.chain(varint::encode(MultiplexHeader::open(0).to_u64()))
// 0-length body (stream has no name)
.chain(varint::encode(0usize))
// "Message"-type packet for an unopened stream
.chain(
varint::encode(
// ID for an unopened stream: 1
MultiplexHeader::message(1, MultiplexEnd::Initiator).to_u64(),
).into_iter(),
)
// Body: `dummy_length` of zeroes
.chain(varint::encode(dummy_length))
.chain(iter::repeat(0).take(dummy_length))
// "Message"-type packet for an opened stream
.chain(
varint::encode(
// ID for an opened stream: 0
MultiplexHeader::message(0, MultiplexEnd::Initiator).to_u64(),
).into_iter(),
)
.chain(varint::encode(message.len()))
.chain(message.iter().cloned())
.collect::<Vec<_>>();
let mplex = Multiplex::listen(io::Cursor::new(input));
let mut substream = mplex.inbound().wait().unwrap();
assert_eq!(substream.id(), 0);
assert_eq!(substream.name(), None);
let mut buf = vec![0; message.len()];
assert!(substream.read(&mut buf).is_ok());
assert_eq!(&buf, message);
}
#[test]
fn can_close_streams() {
use std::iter;
// Dummy data in the body of the close packet (since the de facto protocol is to accept but
// ignore this data)
let dummy_length = 64;
let input = iter::empty()
// Open a stream
.chain(varint::encode(MultiplexHeader::open(0).to_u64()))
// 0-length body (stream has no name)
.chain(varint::encode(0usize))
// Immediately close the stream
.chain(
varint::encode(
// ID for an unopened stream: 1
MultiplexHeader::close(0, MultiplexEnd::Initiator).to_u64(),
).into_iter(),
)
.chain(varint::encode(dummy_length))
.chain(iter::repeat(0).take(dummy_length))
// Send packet to the closed stream
.chain(
varint::encode(
// ID for an opened stream: 0
MultiplexHeader::message(0, MultiplexEnd::Initiator).to_u64(),
).into_iter(),
)
.chain(varint::encode(dummy_length))
.chain(iter::repeat(0).take(dummy_length))
.collect::<Vec<_>>();
let mplex = Multiplex::listen(io::Cursor::new(input));
let mut substream = mplex.inbound().wait().unwrap();
assert_eq!(substream.id(), 0);
assert_eq!(substream.name(), None);
assert_eq!(substream.read(&mut [0; 100][..]).unwrap(), 0);
}
#[test]
fn real_world_data() {
let data: Vec<u8> = vec![
// Open stream 1
8,
0,
// Message for stream 1 (length 20)
10,
20,
19,
47,
109,
117,
108,
116,
105,
115,
116,
114,
101,
97,
109,
47,
49,
46,
48,
46,
48,
10,
];
let mplex = Multiplex::listen(io::Cursor::new(data));
let mut substream = mplex.inbound().wait().unwrap();
assert_eq!(substream.id(), 1);
assert_eq!(substream.name(), None);
let mut out = vec![];
for _ in 0..20 {
let mut buf = [0; 1];
assert_eq!(substream.read(&mut buf[..]).unwrap(), 1);
out.push(buf[0]);
}
assert_eq!(out[0], 19);
assert_eq!(&out[1..0x14 - 1], b"/multistream/1.0.0");
assert_eq!(out[0x14 - 1], 0x0a);
} }
} }

multiplex-rs/src/read.rs Normal file

@ -0,0 +1,402 @@
// Copyright 2017 Parity Technologies (UK) Ltd.
//
// Permission is hereby granted, free of charge, to any person obtaining a
// copy of this software and associated documentation files (the "Software"),
// to deal in the Software without restriction, including without limitation
// the rights to use, copy, modify, merge, publish, distribute, sublicense,
// and/or sell copies of the Software, and to permit persons to whom the
// Software is furnished to do so, subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in
// all copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
// OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
// FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
// DEALINGS IN THE SOFTWARE.
use {bytes, varint};
use futures::Async;
use futures::task;
use header::{MultiplexHeader, PacketType};
use std::io;
use tokio_io::AsyncRead;
use shared::SubstreamMetadata;
pub enum NextMultiplexState {
NewStream(u32),
ParsingMessageBody(u32),
Ignore,
}
pub enum MultiplexReadState {
Header {
state: varint::DecoderState<u64>,
},
BodyLength {
state: varint::DecoderState<usize>,
next: NextMultiplexState,
},
NewStream {
substream_id: u32,
name: bytes::BytesMut,
remaining_bytes: usize,
},
ParsingMessageBody {
substream_id: u32,
remaining_bytes: usize,
},
Ignore {
remaining_bytes: usize,
},
}
impl Default for MultiplexReadState {
fn default() -> Self {
MultiplexReadState::Header {
state: Default::default(),
}
}
}
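// Builds a zero-filled buffer of `capacity` bytes, extending in 1 KiB steps;
// zeroing up front avoids handing the reader uninitialised memory (presumably
// what the `TODO: Uninit buffer` note further down is about).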
fn create_buffer(capacity: usize) -> bytes::BytesMut {
let mut buffer = bytes::BytesMut::with_capacity(capacity);
let zeroes = [0; 1024];
let mut cap = capacity;
while cap > 0 {
let len = cap.min(zeroes.len());
buffer.extend_from_slice(&zeroes[..len]);
cap -= len;
}
buffer
}
pub fn read_stream<'a, O: Into<Option<(u32, &'a mut [u8])>>, T: AsyncRead>(
lock: &mut ::shared::MultiplexShared<T>,
stream_data: O,
) -> io::Result<usize> {
use self::MultiplexReadState::*;
let mut stream_data = stream_data.into();
let stream_has_been_gracefully_closed = stream_data
.as_ref()
.and_then(|&(id, _)| lock.open_streams.get(&id))
.map(|meta| !meta.open())
.unwrap_or(false);
let mut on_block: io::Result<usize> = if stream_has_been_gracefully_closed {
Ok(0)
} else {
Err(io::ErrorKind::WouldBlock.into())
};
loop {
match lock.read_state.take().unwrap_or_default() {
Header {
state: mut varint_state,
} => {
match varint_state.read(&mut lock.stream) {
Ok(Async::Ready(header)) => {
let header = if let Some(header) = header {
header
} else {
return Ok(0);
};
let MultiplexHeader {
substream_id,
packet_type,
} = MultiplexHeader::parse(header).map_err(|err| {
io::Error::new(
io::ErrorKind::Other,
format!("Error parsing header: {:?}", err),
)
})?;
match packet_type {
PacketType::Open => {
lock.read_state = Some(BodyLength {
state: Default::default(),
next: NextMultiplexState::NewStream(substream_id),
})
}
PacketType::Message(_) => {
lock.read_state = Some(BodyLength {
state: Default::default(),
next: NextMultiplexState::ParsingMessageBody(substream_id),
})
}
// NOTE: What's the difference between close and reset?
PacketType::Close(_) | PacketType::Reset(_) => {
lock.read_state = Some(BodyLength {
state: Default::default(),
next: NextMultiplexState::Ignore,
});
lock.close_stream(substream_id);
}
}
}
Ok(Async::NotReady) => {
lock.read_state = Some(Header {
state: varint_state,
});
return on_block;
}
Err(error) => {
return if let varint::Error(varint::ErrorKind::Io(inner), ..) = error {
Err(inner)
} else {
Err(io::Error::new(io::ErrorKind::Other, error.description()))
};
}
}
}
BodyLength {
state: mut varint_state,
next,
} => {
use self::NextMultiplexState::*;
match varint_state
.read(&mut lock.stream)
.map_err(|_| io::Error::new(io::ErrorKind::Other, "Error reading varint"))?
{
Async::Ready(length) => {
// TODO: Limit `length` to prevent resource-exhaustion DOS
let length = if let Some(length) = length {
length
} else {
return Ok(0);
};
lock.read_state = match next {
Ignore => Some(MultiplexReadState::Ignore {
remaining_bytes: length,
}),
NewStream(substream_id) => {
if length == 0 {
lock.to_open.insert(substream_id, None);
None
} else {
Some(MultiplexReadState::NewStream {
// TODO: Uninit buffer
name: create_buffer(length),
remaining_bytes: length,
substream_id,
})
}
}
ParsingMessageBody(substream_id) => {
let is_open = lock.open_streams
.get(&substream_id)
.map(SubstreamMetadata::open)
.unwrap_or_else(|| lock.to_open.contains_key(&substream_id));
if is_open {
Some(MultiplexReadState::ParsingMessageBody {
remaining_bytes: length,
substream_id,
})
} else {
Some(MultiplexReadState::Ignore {
remaining_bytes: length,
})
}
}
};
}
Async::NotReady => {
lock.read_state = Some(BodyLength {
state: varint_state,
next,
});
return on_block;
}
}
}
NewStream {
substream_id,
mut name,
remaining_bytes,
} => {
if remaining_bytes == 0 {
lock.to_open.insert(substream_id, Some(name.freeze()));
lock.read_state = None;
} else {
let cursor_pos = name.len() - remaining_bytes;
let consumed = lock.stream.read(&mut name[cursor_pos..]);
match consumed {
Ok(consumed) => {
let new_remaining = remaining_bytes - consumed;
lock.read_state = Some(NewStream {
substream_id,
name,
remaining_bytes: new_remaining,
});
}
Err(ref err) if err.kind() == io::ErrorKind::WouldBlock => {
lock.read_state = Some(NewStream {
substream_id,
name,
remaining_bytes,
});
return on_block;
}
Err(other) => {
lock.read_state = Some(NewStream {
substream_id,
name,
remaining_bytes,
});
return Err(other);
}
}
}
}
ParsingMessageBody {
substream_id,
remaining_bytes,
} => {
if let Some((ref mut id, ref mut buf)) = stream_data {
use MultiplexReadState::*;
if remaining_bytes == 0 {
lock.read_state = None;
return on_block;
} else if substream_id == *id {
let number_read = *on_block.as_ref().unwrap_or(&0);
if buf.len() == 0 {
return Ok(0);
} else if number_read >= buf.len() {
lock.read_state = Some(ParsingMessageBody {
substream_id,
remaining_bytes,
});
return on_block;
}
let read_result = {
// We know this won't panic because of the earlier
// `number_read >= buf.len()` check
let new_len = (buf.len() - number_read).min(remaining_bytes);
let slice = &mut buf[number_read..number_read + new_len];
lock.stream.read(slice)
};
match read_result {
Ok(consumed) => {
let new_remaining = remaining_bytes - consumed;
lock.read_state = Some(ParsingMessageBody {
substream_id,
remaining_bytes: new_remaining,
});
on_block = Ok(number_read + consumed);
}
Err(ref err) if err.kind() == io::ErrorKind::WouldBlock => {
lock.read_state = Some(ParsingMessageBody {
substream_id,
remaining_bytes,
});
return on_block;
}
Err(other) => {
lock.read_state = Some(ParsingMessageBody {
substream_id,
remaining_bytes,
});
return Err(other);
}
}
} else {
lock.read_state = Some(ParsingMessageBody {
substream_id,
remaining_bytes,
});
if let Some(task) = lock.open_streams
.get(&substream_id)
.and_then(SubstreamMetadata::read_task)
{
task.notify();
}
let write = lock.open_streams
.get(id)
.and_then(SubstreamMetadata::write_task)
.cloned();
lock.open_streams.insert(
*id,
SubstreamMetadata::Open {
read: Some(task::current()),
write,
},
);
// We cannot make progress here, another stream has to accept this packet
return on_block;
}
} else {
lock.read_state = Some(ParsingMessageBody {
substream_id,
remaining_bytes,
});
// We cannot make progress here, a stream has to accept this packet
return on_block;
}
}
Ignore {
mut remaining_bytes,
} => {
let mut ignore_buf: [u8; 256] = [0; 256];
loop {
if remaining_bytes == 0 {
lock.read_state = None;
break;
} else {
let new_len = ignore_buf.len().min(remaining_bytes);
match lock.stream.read(&mut ignore_buf[..new_len]) {
Ok(consumed) => {
remaining_bytes -= consumed;
lock.read_state = Some(Ignore {
remaining_bytes: remaining_bytes,
});
}
Err(ref err) if err.kind() == io::ErrorKind::WouldBlock => {
lock.read_state = Some(Ignore { remaining_bytes });
return on_block;
}
Err(other) => {
lock.read_state = Some(Ignore { remaining_bytes });
return Err(other);
}
}
}
}
}
}
}
}

multiplex-rs/src/shared.rs Normal file

@ -0,0 +1,108 @@
// Copyright 2017 Parity Technologies (UK) Ltd.
//
// Permission is hereby granted, free of charge, to any person obtaining a
// copy of this software and associated documentation files (the "Software"),
// to deal in the Software without restriction, including without limitation
// the rights to use, copy, modify, merge, publish, distribute, sublicense,
// and/or sell copies of the Software, and to permit persons to whom the
// Software is furnished to do so, subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in
// all copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
// OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
// FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
// DEALINGS IN THE SOFTWARE.
use read::MultiplexReadState;
use write::MultiplexWriteState;
use std::collections::HashMap;
use bytes::Bytes;
use arrayvec::ArrayVec;
use futures::task::Task;
const BUF_SIZE: usize = 1024;
pub type ByteBuf = ArrayVec<[u8; BUF_SIZE]>;
pub enum SubstreamMetadata {
Closed,
Open {
read: Option<Task>,
write: Option<Task>,
},
}
impl SubstreamMetadata {
pub fn open(&self) -> bool {
match *self {
SubstreamMetadata::Closed => false,
SubstreamMetadata::Open { .. } => true,
}
}
pub fn read_task(&self) -> Option<&Task> {
match *self {
SubstreamMetadata::Closed => None,
SubstreamMetadata::Open { ref read, .. } => read.as_ref(),
}
}
pub fn write_task(&self) -> Option<&Task> {
match *self {
SubstreamMetadata::Closed => None,
SubstreamMetadata::Open { ref write, .. } => write.as_ref(),
}
}
}
// TODO: Split reading and writing into different structs and have information shared between the
// two in a `RwLock`, since `open_streams` and `to_open` are mostly read-only.
pub struct MultiplexShared<T> {
// We use `Option` in order to take ownership of heap allocations within `DecoderState` and
// `BytesMut`. If this is ever observably `None` then something has panicked or the underlying
// stream returned an error.
pub read_state: Option<MultiplexReadState>,
pub write_state: Option<MultiplexWriteState>,
pub stream: T,
// true if the stream is open, false otherwise
pub open_streams: HashMap<u32, SubstreamMetadata>,
// TODO: Should we use a version of this with a fixed size that doesn't allocate and return
// `WouldBlock` if it's full?
pub to_open: HashMap<u32, Option<Bytes>>,
}
impl<T> MultiplexShared<T> {
pub fn new(stream: T) -> Self {
MultiplexShared {
read_state: Default::default(),
write_state: Default::default(),
open_streams: Default::default(),
to_open: Default::default(),
stream: stream,
}
}
pub fn open_stream(&mut self, id: u32) -> bool {
self.open_streams
.entry(id)
.or_insert(SubstreamMetadata::Open {
read: None,
write: None,
})
.open()
}
pub fn close_stream(&mut self, id: u32) {
self.open_streams.insert(id, SubstreamMetadata::Closed);
}
}
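// Note: `ByteBuf` is a fixed-capacity `ArrayVec`, so anything beyond the first
// `BUF_SIZE` (1024) bytes of `slice` is silently dropped here.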
pub fn buf_from_slice(slice: &[u8]) -> ByteBuf {
slice.iter().cloned().take(BUF_SIZE).collect()
}

multiplex-rs/src/write.rs Normal file

@ -0,0 +1,188 @@
// Copyright 2017 Parity Technologies (UK) Ltd.
//
// Permission is hereby granted, free of charge, to any person obtaining a
// copy of this software and associated documentation files (the "Software"),
// to deal in the Software without restriction, including without limitation
// the rights to use, copy, modify, merge, publish, distribute, sublicense,
// and/or sell copies of the Software, and to permit persons to whom the
// Software is furnished to do so, subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in
// all copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
// OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
// FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
// DEALINGS IN THE SOFTWARE.
use shared::{ByteBuf, MultiplexShared, SubstreamMetadata};
use header::MultiplexHeader;
use varint;
use futures::task;
use std::io;
use tokio_io::AsyncWrite;
#[derive(Copy, Clone, PartialEq, Eq, Debug)]
pub enum RequestType {
Meta,
Substream,
}
#[derive(Copy, Clone, PartialEq, Eq, Debug)]
pub struct WriteRequest {
header: MultiplexHeader,
request_type: RequestType,
}
impl WriteRequest {
pub fn substream(header: MultiplexHeader) -> Self {
WriteRequest {
header,
request_type: RequestType::Substream,
}
}
pub fn meta(header: MultiplexHeader) -> Self {
WriteRequest {
header,
request_type: RequestType::Meta,
}
}
}
#[derive(Default, Debug)]
pub struct MultiplexWriteState {
current: Option<(WriteRequest, MultiplexWriteStateInner)>,
queued: Option<WriteRequest>,
// TODO: Actually close these
to_close: Vec<u32>,
}
#[derive(Debug)]
pub enum MultiplexWriteStateInner {
WriteHeader { state: varint::EncoderState<u64> },
BodyLength { state: varint::EncoderState<usize> },
Body { size: usize },
}
pub fn write_stream<T: AsyncWrite>(
lock: &mut MultiplexShared<T>,
write_request: WriteRequest,
buf: &mut io::Cursor<ByteBuf>,
) -> io::Result<usize> {
use futures::Async;
use num_traits::cast::ToPrimitive;
use varint::WriteState;
use write::MultiplexWriteStateInner::*;
let mut on_block = Err(io::ErrorKind::WouldBlock.into());
let mut write_state = lock.write_state.take().unwrap_or_default();
let (request, mut state) = write_state.current.take().unwrap_or_else(|| {
(
write_request,
MultiplexWriteStateInner::WriteHeader {
state: varint::EncoderState::new(write_request.header.to_u64()),
},
)
});
let id = write_request.header.substream_id;
match (request.request_type, write_request.request_type) {
(RequestType::Substream, RequestType::Substream) if request.header.substream_id != id => {
let read = lock.open_streams
.get(&id)
.and_then(SubstreamMetadata::read_task)
.cloned();
if let Some(task) = lock.open_streams
.get(&request.header.substream_id)
.and_then(SubstreamMetadata::write_task)
{
task.notify();
}
lock.open_streams.insert(
id,
SubstreamMetadata::Open {
write: Some(task::current()),
read,
},
);
lock.write_state = Some(write_state);
return on_block;
}
(RequestType::Substream, RequestType::Meta)
| (RequestType::Meta, RequestType::Substream) => {
lock.write_state = Some(write_state);
return on_block;
}
_ => {}
}
loop {
// Err = should return, Ok = continue
let new_state = match state {
WriteHeader {
state: mut inner_state,
} => match inner_state
.write(&mut lock.stream)
.map_err(|_| io::ErrorKind::Other)?
{
Async::Ready(WriteState::Done(_)) => Ok(BodyLength {
state: varint::EncoderState::new(buf.get_ref().len()),
}),
Async::Ready(WriteState::Pending(_)) | Async::NotReady => {
Err(Some(WriteHeader { state: inner_state }))
}
},
BodyLength {
state: mut inner_state,
} => match inner_state
.write(&mut lock.stream)
.map_err(|_| io::ErrorKind::Other)?
{
Async::Ready(WriteState::Done(_)) => Ok(Body {
size: inner_state.source().to_usize().unwrap_or(::std::usize::MAX),
}),
Async::Ready(WriteState::Pending(_)) => Ok(BodyLength { state: inner_state }),
Async::NotReady => Err(Some(BodyLength { state: inner_state })),
},
Body { size } => {
if buf.position() == buf.get_ref().len() as u64 {
Err(None)
} else {
match lock.stream.write(&buf.get_ref()[buf.position() as usize..]) {
Ok(just_written) => {
let cur_pos = buf.position();
buf.set_position(cur_pos + just_written as u64);
on_block = Ok(on_block.unwrap_or(0) + just_written);
Ok(Body {
size: size - just_written,
})
}
Err(ref err) if err.kind() == io::ErrorKind::WouldBlock => {
Err(Some(Body { size }))
}
Err(other) => {
return Err(other);
}
}
}
}
};
match new_state {
Ok(new_state) => state = new_state,
Err(new_state) => {
write_state.current = new_state.map(|state| (request, state));
lock.write_state = Some(write_state);
return on_block;
}
}
}
}


@ -5,6 +5,8 @@ authors = ["Parity Technologies <admin@parity.io>"]
 [dependencies]
 num-bigint = "0.1.40"
+num-traits = "0.1.40"
 bytes = "0.4.5"
 tokio-io = "0.1"
 futures = "0.1"
+error-chain = "0.11.0"


@ -1,134 +1,404 @@
+#![warn(missing_docs)]
+
+//! Encoding and decoding state machines for protobuf varints
+
+// TODO: Non-allocating `BigUint`?
 extern crate num_bigint;
+extern crate num_traits;
 extern crate tokio_io;
 extern crate bytes;
 extern crate futures;
+#[macro_use]
+extern crate error_chain;
 
 use bytes::BytesMut;
+use futures::{Poll, Async};
 use num_bigint::BigUint;
-use tokio_io::AsyncRead;
+use num_traits::ToPrimitive;
+use tokio_io::{AsyncRead, AsyncWrite};
 use tokio_io::codec::Decoder;
 use std::io;
 use std::io::prelude::*;
 
-// TODO: error-chain
-pub struct ParseError;
+mod errors {
+    error_chain! {
+        errors {
+            ParseError {
+                description("error parsing varint")
+                display("error parsing varint")
+            }
+
+            WriteError {
+                description("error writing varint")
+                display("error writing varint")
+            }
+        }
 
-#[derive(Default)]
-pub struct DecoderState {
+        foreign_links {
+            Io(::std::io::Error);
+        }
// TODO: Non-allocating `BigUint`?
accumulator: BigUint,
shift: usize,
}
impl DecoderState {
pub fn new() -> Self {
Default::default()
}
fn decode_one(mut self, byte: u8) -> Result<BigUint, Self> {
self.accumulator = self.accumulator | (BigUint::from(byte & 0x7F) << self.shift);
self.shift += 7;
if byte & 0x80 == 0 {
Ok(self.accumulator)
} else {
Err(self)
} }
} }
}
pub use errors::{Error, ErrorKind};

const USABLE_BITS_PER_BYTE: usize = 7;

/// The state struct for the varint-to-bytes FSM
#[derive(Debug)]
pub struct EncoderState<T> {
    source: T,
    // A "chunk" is a section of the contained `BigUint` `USABLE_BITS_PER_BYTE` bits long
    num_chunks: usize,
    cur_chunk: usize,
}
/// Whether or not the varint writing was completed
#[derive(Debug, Eq, PartialEq, Copy, Clone)]
pub enum WriteState {
/// The encoder has finished writing
Done(usize),
/// The encoder still must write more bytes
Pending(usize),
}
fn ceil_div(a: usize, b: usize) -> usize {
(a + b - 1) / b
}
/// A trait to get the minimum number of bits required to represent a number
pub trait Bits {
/// The minimum number of bits required to represent `self`
fn bits(&self) -> usize;
}
impl Bits for BigUint {
fn bits(&self) -> usize {
BigUint::bits(self)
}
}
macro_rules! impl_bits {
($t:ty) => {
impl Bits for $t {
fn bits(&self) -> usize {
(std::mem::size_of::<$t>() * 8) - self.leading_zeros() as usize
            }
        }
    }
}

impl_bits!(usize);
impl_bits!(u64);
impl_bits!(u32);
impl_bits!(u16);
impl_bits!(u8);
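A quick sanity check of the chunk arithmetic (an illustration, not code from this commit; it assumes it runs inside this crate so the private `ceil_div` is visible):

assert_eq!(300usize.bits(), 9); // 300 = 0b1_0010_1100
assert_eq!(ceil_div(9, USABLE_BITS_PER_BYTE), 2); // two 7-bit chunks, so two bytes on the wire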
/// Helper trait to allow multiple integer types to be encoded
pub trait EncoderHelper: Sized {
/// Write as much as possible of the inner integer to the output `AsyncWrite`
fn write<W: AsyncWrite>(encoder: &mut EncoderState<Self>, output: W)
-> Poll<WriteState, Error>;
}

/// Helper trait to allow multiple integer types to be decoded
pub trait DecoderHelper: Sized {
/// Decode a single byte
fn decode_one(decoder: &mut DecoderState<Self>, byte: u8) -> Option<Self>;
/// Read as much of the varint as possible
fn read<R: AsyncRead>(decoder: &mut DecoderState<Self>, input: R) -> Poll<Option<Self>, Error>;
}
macro_rules! impl_decoderstate {
($t:ty) => { impl_decoderstate!($t, <$t>::from, |v| v); };
($t:ty, $make_fn:expr) => { impl_decoderstate!($t, $make_fn, $make_fn); };
($t:ty, $make_fn:expr, $make_shift_fn:expr) => {
impl DecoderHelper for $t {
#[inline]
fn decode_one(decoder: &mut DecoderState<Self>, byte: u8) -> Option<$t> {
decoder.accumulator.take().and_then(|accumulator| {
let out = accumulator |
(
$make_fn(byte & 0x7F) <<
$make_shift_fn(decoder.shift * USABLE_BITS_PER_BYTE)
);
decoder.shift += 1;
if byte & 0x80 == 0 {
Some(out)
} else {
decoder.accumulator = AccumulatorState::InProgress(out);
None
}
})
}
fn read<R: AsyncRead>(
decoder: &mut DecoderState<Self>,
mut input: R
) -> Poll<Option<Self>, Error> {
if decoder.accumulator == AccumulatorState::Finished {
return Err(Error::with_chain(
io::Error::new(
io::ErrorKind::Other,
"Attempted to parse a second varint (create a new instance!)",
),
ErrorKind::ParseError,
));
}
loop {
// We read one at a time to prevent consuming too much of the buffer.
let mut buffer: [u8; 1] = [0];
match input.read_exact(&mut buffer) {
Ok(()) => {
if let Some(out) = Self::decode_one(decoder, buffer[0]) {
break Ok(Async::Ready(Some(out)));
}
}
Err(ref err) if err.kind() == io::ErrorKind::WouldBlock => {
break Ok(Async::NotReady);
}
Err(inner) => if decoder.accumulator == AccumulatorState::NotStarted {
break Ok(Async::Ready(None));
} else {
break Err(Error::with_chain(inner, ErrorKind::ParseError))
},
}
}
}
}
}
}
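Decoding `[0b1010_1100, 0b0000_0010]` by hand shows what `decode_one` accumulates at each step (an illustration, not part of the commit):

// byte 0: payload 0b010_1100 = 44, continuation bit set, so keep reading
// byte 1: payload 0b000_0010 = 2, no continuation; it lands at shift 7
assert_eq!(44 | (2 << 7), 300);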
macro_rules! impl_encoderstate {
($t:ty) => { impl_encoderstate!($t, <$t>::from); };
($t:ty, $make_fn:expr) => {
impl EncoderHelper for $t {
/// Write as much as possible of the inner integer to the output `AsyncWrite`
fn write<W: AsyncWrite>(
encoder: &mut EncoderState<Self>,
mut output: W,
) -> Poll<WriteState, Error> {
fn encode_one(encoder: &EncoderState<$t>) -> Option<u8> {
let last_chunk = encoder.num_chunks - 1;
if encoder.cur_chunk > last_chunk {
return None;
}
let masked = (&encoder.source >> (encoder.cur_chunk * USABLE_BITS_PER_BYTE)) &
$make_fn((1 << USABLE_BITS_PER_BYTE) - 1usize);
let masked = masked.to_u8().expect(
"Masked with 0b0111_1111, is less than u8::MAX, QED",
);
if encoder.cur_chunk == last_chunk {
Some(masked)
} else {
Some(masked | (1 << USABLE_BITS_PER_BYTE))
}
}
let mut written = 0usize;
loop {
if let Some(byte) = encode_one(&encoder) {
let buffer: [u8; 1] = [byte];
match output.write_all(&buffer) {
Ok(()) => {
written += 1;
encoder.cur_chunk += 1;
}
Err(ref err) if err.kind() == io::ErrorKind::WouldBlock => {
break if written == 0 {
Ok(Async::NotReady)
} else {
Ok(Async::Ready(WriteState::Pending(written)))
};
}
Err(inner) => break Err(
Error::with_chain(inner, ErrorKind::WriteError)
),
}
} else {
break Ok(Async::Ready(WriteState::Done(written)));
}
}
}
}
}
}
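Running the same number through `encode_one` in the other direction (an illustration, not part of the commit):

// chunk 0: 300 & 0x7F = 44, not the last chunk, so the continuation bit is set
// chunk 1: (300 >> 7) & 0x7F = 2, last chunk, so it is emitted as-is
assert_eq!((300u32 & 0x7F) as u8 | 0x80, 0b1010_1100);
assert_eq!((300u32 >> 7) as u8, 0b0000_0010);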
impl_encoderstate!(usize);
impl_encoderstate!(BigUint);
impl_encoderstate!(u64, (|val| val as u64));
impl_encoderstate!(u32, (|val| val as u32));
impl_decoderstate!(usize);
impl_decoderstate!(BigUint);
impl_decoderstate!(u64, (|val| val as u64));
impl_decoderstate!(u32, (|val| val as u32));
impl<T> EncoderState<T> {
pub fn source(&self) -> &T {
&self.source
}
}
impl<T: Bits> EncoderState<T> {
/// Create a new encoder
pub fn new(inner: T) -> Self {
let bits = inner.bits();
EncoderState {
source: inner,
num_chunks: ceil_div(bits, USABLE_BITS_PER_BYTE).max(1),
cur_chunk: 0,
}
}
}
impl<T: EncoderHelper> EncoderState<T> {
/// Write as much as possible of the inner integer to the output `AsyncWrite`
pub fn write<W: AsyncWrite>(&mut self, output: W) -> Poll<WriteState, Error> {
T::write(self, output)
}
}
#[derive(Debug, PartialEq, Eq)]
enum AccumulatorState<T> {
InProgress(T),
NotStarted,
Finished,
}
impl<T: Default> AccumulatorState<T> {
fn take(&mut self) -> Option<T> {
use std::mem;
use AccumulatorState::*;
match mem::replace(self, AccumulatorState::Finished) {
InProgress(inner) => Some(inner),
NotStarted => Some(Default::default()),
Finished => None,
}
}
}
/// The state struct for the varint bytes-to-bigint FSM
#[derive(Debug)]
pub struct DecoderState<T> {
accumulator: AccumulatorState<T>,
shift: usize,
}
impl<T: Default> Default for DecoderState<T> {
fn default() -> Self {
DecoderState {
accumulator: AccumulatorState::NotStarted,
shift: 0,
}
}
}
impl<T: Default> DecoderState<T> {
/// Make a new decoder
    pub fn new() -> Self {
        Default::default()
    }
}

impl<T: DecoderHelper> DecoderState<T> {
    /// Read as much of the varint as possible from the given input
pub fn read<R: AsyncRead>(&mut self, input: R) -> Poll<Option<T>, Error> {
T::read(self, input)
}
}
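A sketch (not from this commit) of feeding a varint to `DecoderState` one byte at a time, the way a codec would when the bytes arrive in separate reads; it assumes the `DecoderHelper` trait is in scope:

let mut state: DecoderState<usize> = DecoderState::new();
assert_eq!(<usize as DecoderHelper>::decode_one(&mut state, 0b1010_1100), None); // needs more bytes
assert_eq!(<usize as DecoderHelper>::decode_one(&mut state, 0b0000_0010), Some(300));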
/// Wrapper around `DecoderState` to make a `tokio` `Decoder`
#[derive(Debug)]
pub struct VarintDecoder<T> {
state: Option<DecoderState<T>>,
}
impl<T> Default for VarintDecoder<T> {
fn default() -> Self {
VarintDecoder { state: None }
}
}
impl<T> VarintDecoder<T> {
/// Make a new decoder
pub fn new() -> Self {
Default::default()
}
}
impl<T: Default + DecoderHelper> Decoder for VarintDecoder<T> {
type Item = T;
    type Error = io::Error;

    fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
        loop {
            if src.is_empty() && self.state.is_some() {
                break Err(io::Error::from(io::ErrorKind::UnexpectedEof));
            } else {
                // We know that the length is not 0, so this cannot fail.
                let first_byte = src.split_to(1)[0];
                let mut state = self.state.take().unwrap_or_default();
                let out = T::decode_one(&mut state, first_byte);

                if let Some(out) = out {
                    break Ok(Some(out));
                } else {
                    self.state = Some(state);
                }
            }
        }
    }
}
/// Synchronously decode a number from a `Read`
pub fn decode<R: Read, T: Default + DecoderHelper>(mut input: R) -> errors::Result<T> {
    let mut decoder = DecoderState::default();

    loop {
        // We read one at a time to prevent consuming too much of the buffer.
        let mut buffer: [u8; 1] = [0];

        match input.read_exact(&mut buffer) {
            Ok(()) => {
                if let Some(out) = T::decode_one(&mut decoder, buffer[0]) {
                    break Ok(out);
                }
            }
            Err(inner) => break Err(Error::with_chain(inner, ErrorKind::ParseError)),
        }
    }
}

/// Synchronously encode a number into a `Vec<u8>`
pub fn encode<T: EncoderHelper + Bits>(input: T) -> Vec<u8> {
    let mut encoder = EncoderState::new(input);
    let mut out = io::Cursor::new(Vec::with_capacity(1));

    match T::write(&mut encoder, &mut out).expect("Writing to a vec should never fail, Q.E.D") {
        Async::Ready(_) => out.into_inner(),
        Async::NotReady => unreachable!(),
    }
}
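A round-trip through the synchronous helpers might look like this (a sketch, not a test from this commit):

let bytes = encode(300usize);
assert_eq!(bytes, vec![0b1010_1100, 0b0000_0010]);

let back: usize = decode(&bytes[..]).unwrap();
assert_eq!(back, 300);

// Thanks to `.max(1)` in `EncoderState::new`, zero still occupies one byte.
assert_eq!(encode(0usize), vec![0]);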
#[cfg(test)]
mod tests {
    use super::{decode, VarintDecoder, EncoderState};
    use tokio_io::codec::FramedRead;
    use num_bigint::BigUint;
    use futures::{Future, Stream};

    #[test]
    fn can_decode_basic_biguint() {
        assert_eq!(
            BigUint::from(300u16),
            decode(&[0b10101100, 0b00000010][..]).unwrap()
@ -136,7 +406,7 @@ mod tests {
    }

    #[test]
    fn can_decode_basic_biguint_async() {
        let result = FramedRead::new(&[0b10101100, 0b00000010][..], VarintDecoder::new())
            .into_future()
            .map(|(out, _)| out)
@ -146,11 +416,77 @@ mod tests {
        assert_eq!(result.unwrap(), Some(BigUint::from(300u16)));
    }
#[test]
fn can_decode_trivial_usize_async() {
let result = FramedRead::new(&[1][..], VarintDecoder::new())
.into_future()
.map(|(out, _)| out)
.map_err(|(out, _)| out)
.wait();
assert_eq!(result.unwrap(), Some(1usize));
}
#[test]
fn can_decode_basic_usize_async() {
let result = FramedRead::new(&[0b10101100, 0b00000010][..], VarintDecoder::new())
.into_future()
.map(|(out, _)| out)
.map_err(|(out, _)| out)
.wait();
assert_eq!(result.unwrap(), Some(300usize));
}
#[test]
fn can_encode_basic_biguint_async() {
use std::io::Cursor;
use futures::Async;
use super::WriteState;
let mut out = vec![0u8; 2];
{
let writable: Cursor<&mut [_]> = Cursor::new(&mut out);
let mut state = EncoderState::new(BigUint::from(300usize));
assert_eq!(
state.write(writable).unwrap(),
Async::Ready(WriteState::Done(2))
);
}
assert_eq!(out, vec![0b10101100, 0b00000010]);
}
#[test]
fn can_encode_basic_usize_async() {
use std::io::Cursor;
use futures::Async;
use super::WriteState;
let mut out = vec![0u8; 2];
{
let writable: Cursor<&mut [_]> = Cursor::new(&mut out);
let mut state = EncoderState::new(300usize);
assert_eq!(
state.write(writable).unwrap(),
Async::Ready(WriteState::Done(2))
);
}
assert_eq!(out, vec![0b10101100, 0b00000010]);
}
    #[test]
    fn unexpected_eof_async() {
        use std::io;

        let result = FramedRead::new(&[0b10101100, 0b10000010][..], VarintDecoder::<usize>::new())
            .into_future()
            .map(|(out, _)| out)
            .map_err(|(out, _)| out)