diff --git a/Cargo.toml b/Cargo.toml index 3c1e4621..ef17a95e 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -4,11 +4,7 @@ members = [ "libp2p-ping", "libp2p-secio", "libp2p-swarm", - "libp2p-transport", - "libp2p-host", "libp2p-tcp-transport", - "libp2p-stream-muxer", - "multihash", "multistream-select", "datastore", "rw-stream-sink", diff --git a/circular-buffer/src/lib.rs b/circular-buffer/src/lib.rs index b7f8bd05..00eaebc0 100644 --- a/circular-buffer/src/lib.rs +++ b/circular-buffer/src/lib.rs @@ -9,62 +9,109 @@ extern crate smallvec; -use std::ops::{Deref, DerefMut, Drop}; +use std::ops::Drop; use std::mem::ManuallyDrop; use smallvec::Array; +use owned_slice::OwnedSlice; + /// A slice that owns its elements, but not their storage. This is useful for things like /// `Vec::retain` and `CircularBuffer::pop_slice`, since these functions can return a slice but the /// elements of these slices would be leaked after the slice goes out of scope. `OwnedSlice` simply /// manually drops all its elements when it goes out of scope. -#[derive(Debug, Eq, PartialEq)] -pub struct OwnedSlice<'a, T: 'a>(&'a mut [T]); +pub mod owned_slice { + use std::ops::{Deref, DerefMut, Drop}; + use std::mem::ManuallyDrop; -impl<'a, T: 'a> OwnedSlice<'a, T> { - /// Construct an owned slice from a mutable slice pointer. - /// - /// # Unsafety - /// You must ensure that the memory pointed to by `inner` will not be accessible after the - /// lifetime of the `OwnedSlice`. - pub unsafe fn new(inner: &'a mut [T]) -> Self { - OwnedSlice(inner) + /// A slice that owns its elements, but not their storage. This is useful for things like + /// `Vec::retain` and `CircularBuffer::pop_slice`, since these functions can return a slice but + /// the elements of these slices would be leaked after the slice goes out of scope. `OwnedSlice` + /// simply manually drops all its elements when it goes out of scope. + #[derive(Debug, Eq, PartialEq)] + pub struct OwnedSlice<'a, T: 'a>(&'a mut [T]); + + /// Owning iterator for `OwnedSlice`. + pub struct IntoIter<'a, T: 'a> { + slice: ManuallyDrop>, + index: usize, } -} -impl<'a, T> AsRef<[T]> for OwnedSlice<'a, T> { - fn as_ref(&self) -> &[T] { - self.0 + impl<'a, T> Iterator for IntoIter<'a, T> { + type Item = T; + + fn next(&mut self) -> Option { + use std::ptr; + + let index = self.index; + + if index >= self.slice.len() { + return None; + } + + self.index += 1; + + unsafe { Some(ptr::read(&self.slice[index])) } + } } -} -impl<'a, T> AsMut<[T]> for OwnedSlice<'a, T> { - fn as_mut(&mut self) -> &mut [T] { - self.0 + impl<'a, T: 'a> IntoIterator for OwnedSlice<'a, T> { + type Item = T; + type IntoIter = IntoIter<'a, T>; + + fn into_iter(self) -> Self::IntoIter { + IntoIter { + slice: ManuallyDrop::new(self), + index: 0, + } + } } -} -impl<'a, T> Deref for OwnedSlice<'a, T> { - type Target = [T]; - - fn deref(&self) -> &Self::Target { - self.0 + impl<'a, T: 'a> OwnedSlice<'a, T> { + /// Construct an owned slice from a mutable slice pointer. + /// + /// # Unsafety + /// You must ensure that the memory pointed to by `inner` will not be accessible after the + /// lifetime of the `OwnedSlice`. + pub unsafe fn new(inner: &'a mut [T]) -> Self { + OwnedSlice(inner) + } } -} -impl<'a, T> DerefMut for OwnedSlice<'a, T> { - fn deref_mut(&mut self) -> &mut Self::Target { - self.0 + impl<'a, T> AsRef<[T]> for OwnedSlice<'a, T> { + fn as_ref(&self) -> &[T] { + self.0 + } } -} -impl<'a, T> Drop for OwnedSlice<'a, T> { - fn drop(&mut self) { - use std::ptr; + impl<'a, T> AsMut<[T]> for OwnedSlice<'a, T> { + fn as_mut(&mut self) -> &mut [T] { + self.0 + } + } - for element in self.iter_mut() { - unsafe { - ptr::drop_in_place(element); + impl<'a, T> Deref for OwnedSlice<'a, T> { + type Target = [T]; + + fn deref(&self) -> &Self::Target { + self.0 + } + } + + impl<'a, T> DerefMut for OwnedSlice<'a, T> { + fn deref_mut(&mut self) -> &mut Self::Target { + self.0 + } + } + + impl<'a, T> Drop for OwnedSlice<'a, T> { + fn drop(&mut self) { + use std::ptr; + + for element in self.iter_mut() { + unsafe { + ptr::drop_in_place(element); + } } } } @@ -82,6 +129,12 @@ pub struct CircularBuffer { len: usize, } +impl Default for CircularBuffer { + fn default() -> Self { + Self::new() + } +} + impl PartialEq for CircularBuffer where B::Item: PartialEq, @@ -97,7 +150,7 @@ where } } - return true; + true } } @@ -148,7 +201,9 @@ impl CircularBuffer { /// when the slice goes out of scope), if you're using non-`Drop` types you can use /// `pop_slice_leaky`. pub fn pop_slice(&mut self) -> Option> { - self.pop_slice_leaky().map(OwnedSlice) + self.pop_slice_leaky().map( + |x| unsafe { OwnedSlice::new(x) }, + ) } /// Pop a slice containing the maximum possible contiguous number of elements. Since this buffer @@ -357,18 +412,35 @@ impl CircularBuffer { } } - /// Get a borrow to an element at an index unsafely (causes undefined behaviour if the index is - /// out of bounds). + /// Get a borrow to an element at an index unsafely (behaviour is undefined if the index is out + /// of bounds). pub unsafe fn get_unchecked(&self, index: usize) -> &B::Item { - use std::mem; - - mem::transmute(self.buffer.ptr().offset( + &*self.buffer.ptr().offset( ((index + self.start) % B::size()) as isize, - )) + ) + } + + /// Get a mutable borrow to an element at an index safely (if the index is out of bounds, return + /// `None`). + pub fn get_mut(&mut self, index: usize) -> Option<&mut B::Item> { + if index < self.len { + unsafe { Some(self.get_unchecked_mut(index)) } + } else { + None + } + } + + /// Get a mutable borrow to an element at an index unsafely (behaviour is undefined if the index + /// is out of bounds). + pub unsafe fn get_unchecked_mut(&mut self, index: usize) -> &mut B::Item { + &mut *self.buffer.ptr_mut().offset( + ((index + self.start) % B::size()) as + isize, + ) } // This is not unsafe because it can only leak data, not cause uninit to be read. - fn advance(&mut self, by: usize) { + pub fn advance(&mut self, by: usize) { assert!(by <= self.len); self.start = (self.start + by) % B::size(); @@ -376,6 +448,39 @@ impl CircularBuffer { } } +impl std::ops::Index for CircularBuffer { + type Output = B::Item; + + fn index(&self, index: usize) -> &Self::Output { + if let Some(out) = self.get(index) { + out + } else { + panic!( + "index out of bounds: the len is {} but the index is {}", + self.len, + index + ); + } + } +} + +impl std::ops::IndexMut for CircularBuffer { + fn index_mut(&mut self, index: usize) -> &mut Self::Output { + // We need to do this because borrowck isn't smart enough to understand enum variants + let len = self.len; + + if let Some(out) = self.get_mut(index) { + return out; + } else { + panic!( + "index out of bounds: the len is {} but the index is {}", + len, + index + ); + } + } +} + impl Drop for CircularBuffer { fn drop(&mut self) { while self.pop_slice().is_some() {} diff --git a/libp2p-stream-muxer/Cargo.toml b/libp2p-stream-muxer/Cargo.toml deleted file mode 100644 index 51a831d1..00000000 --- a/libp2p-stream-muxer/Cargo.toml +++ /dev/null @@ -1,8 +0,0 @@ -[package] -name = "libp2p-stream-muxer" -version = "0.1.0" -authors = ["Vurich "] - -[dependencies] -futures = "0.1" -tokio-io = "0.1" diff --git a/libp2p-stream-muxer/src/lib.rs b/libp2p-stream-muxer/src/lib.rs deleted file mode 100644 index 5479c1bd..00000000 --- a/libp2p-stream-muxer/src/lib.rs +++ /dev/null @@ -1,22 +0,0 @@ -extern crate tokio_io; -extern crate futures; - -use futures::stream::Stream; -use tokio_io::{AsyncRead, AsyncWrite}; - -pub trait StreamMuxer { - type Substream: AsyncRead + AsyncWrite; - type InboundSubstreams: Stream; - type OutboundSubstreams: Stream; - - fn inbound(&mut self) -> Self::InboundSubstreams; - fn outbound(&mut self) -> Self::OutboundSubstreams; -} - -#[cfg(test)] -mod tests { - #[test] - fn it_works() { - assert_eq!(2 + 2, 4); - } -} diff --git a/multiplex-rs/Cargo.toml b/multiplex-rs/Cargo.toml index 9f3ccfca..cbfcabe1 100644 --- a/multiplex-rs/Cargo.toml +++ b/multiplex-rs/Cargo.toml @@ -10,6 +10,8 @@ num-bigint = "0.1.40" tokio-io = "0.1" futures = "0.1" parking_lot = "0.4.8" -libp2p-stream-muxer = { path = "../libp2p-stream-muxer" } +arrayvec = "0.4.6" +rand = "0.3.17" +libp2p-swarm = { path = "../libp2p-swarm" } varint = { path = "../varint-rs" } -circular-buffer = { path = "../circular-buffer" } +error-chain = "0.11.0" diff --git a/multiplex-rs/README.md b/multiplex-rs/README.md new file mode 100644 index 00000000..c0007bb1 --- /dev/null +++ b/multiplex-rs/README.md @@ -0,0 +1,3 @@ +# Multiplex + +A Rust implementation of [multiplex](https://github.com/maxogden/multiplex). diff --git a/multiplex-rs/src/header.rs b/multiplex-rs/src/header.rs new file mode 100644 index 00000000..74408c3a --- /dev/null +++ b/multiplex-rs/src/header.rs @@ -0,0 +1,145 @@ +// Copyright 2017 Parity Technologies (UK) Ltd. +// +// Permission is hereby granted, free of charge, to any person obtaining a +// copy of this software and associated documentation files (the "Software"), +// to deal in the Software without restriction, including without limitation +// the rights to use, copy, modify, merge, publish, distribute, sublicense, +// and/or sell copies of the Software, and to permit persons to whom the +// Software is furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS +// OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +// FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +// DEALINGS IN THE SOFTWARE. + +const FLAG_BITS: usize = 3; +const FLAG_MASK: usize = (1usize << FLAG_BITS) - 1; + +pub mod errors { + error_chain! { + errors { + ParseError + } + } +} + +#[derive(Copy, Clone, PartialEq, Eq, Debug)] +pub enum MultiplexEnd { + Initiator, + Receiver, +} + +#[derive(Copy, Clone, PartialEq, Eq, Debug)] +pub struct MultiplexHeader { + pub packet_type: PacketType, + pub substream_id: u32, +} + +#[derive(Copy, Clone, PartialEq, Eq, Debug)] +pub enum PacketType { + Open, + Close(MultiplexEnd), + Reset(MultiplexEnd), + Message(MultiplexEnd), +} + +impl MultiplexHeader { + pub fn open(id: u32) -> Self { + MultiplexHeader { + substream_id: id, + packet_type: PacketType::Open, + } + } + + pub fn close(id: u32, end: MultiplexEnd) -> Self { + MultiplexHeader { + substream_id: id, + packet_type: PacketType::Close(end), + } + } + + pub fn reset(id: u32, end: MultiplexEnd) -> Self { + MultiplexHeader { + substream_id: id, + packet_type: PacketType::Reset(end), + } + } + + pub fn message(id: u32, end: MultiplexEnd) -> Self { + MultiplexHeader { + substream_id: id, + packet_type: PacketType::Message(end), + } + } + + // TODO: Use `u128` or another large integer type instead of bigint since we never use more than + // `pointer width + FLAG_BITS` bits and unconditionally allocating 1-3 `u32`s for that is + // ridiculous (especially since even for small numbers we have to allocate 1 `u32`). + // If this is the future and `BigUint` is better-optimised (maybe by using `Bytes`) then + // forget it. + pub fn parse(header: u64) -> Result { + use num_traits::cast::ToPrimitive; + + let flags = header & FLAG_MASK as u64; + + let substream_id = (header >> FLAG_BITS) + .to_u32() + .ok_or(errors::ErrorKind::ParseError)?; + + // Yes, this is really how it works. No, I don't know why. + let packet_type = match flags { + 0 => PacketType::Open, + + 1 => PacketType::Message(MultiplexEnd::Receiver), + 2 => PacketType::Message(MultiplexEnd::Initiator), + + 3 => PacketType::Close(MultiplexEnd::Receiver), + 4 => PacketType::Close(MultiplexEnd::Initiator), + + 5 => PacketType::Reset(MultiplexEnd::Receiver), + 6 => PacketType::Reset(MultiplexEnd::Initiator), + + _ => { + use std::io; + + return Err(errors::Error::with_chain( + io::Error::new( + io::ErrorKind::Other, + format!("Unexpected packet type: {}", flags), + ), + errors::ErrorKind::ParseError, + )); + } + }; + + Ok(MultiplexHeader { + substream_id, + packet_type, + }) + } + + pub fn to_u64(&self) -> u64 { + let packet_type_id = match self.packet_type { + PacketType::Open => 0, + + PacketType::Message(MultiplexEnd::Receiver) => 1, + PacketType::Message(MultiplexEnd::Initiator) => 2, + + PacketType::Close(MultiplexEnd::Receiver) => 3, + PacketType::Close(MultiplexEnd::Initiator) => 4, + + PacketType::Reset(MultiplexEnd::Receiver) => 5, + PacketType::Reset(MultiplexEnd::Initiator) => 6, + }; + + let substream_id = (self.substream_id as u64) << FLAG_BITS; + + substream_id | packet_type_id + } +} diff --git a/multiplex-rs/src/lib.rs b/multiplex-rs/src/lib.rs index b32f55bb..7e962915 100644 --- a/multiplex-rs/src/lib.rs +++ b/multiplex-rs/src/lib.rs @@ -1,22 +1,56 @@ +// Copyright 2017 Parity Technologies (UK) Ltd. +// +// Permission is hereby granted, free of charge, to any person obtaining a +// copy of this software and associated documentation files (the "Software"), +// to deal in the Software without restriction, including without limitation +// the rights to use, copy, modify, merge, publish, distribute, sublicense, +// and/or sell copies of the Software, and to permit persons to whom the +// Software is furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS +// OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +// FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +// DEALINGS IN THE SOFTWARE. + +extern crate arrayvec; extern crate bytes; +#[macro_use] +extern crate error_chain; extern crate futures; -extern crate libp2p_stream_muxer; -extern crate tokio_io; -extern crate varint; +extern crate libp2p_swarm as swarm; extern crate num_bigint; extern crate num_traits; extern crate parking_lot; -extern crate circular_buffer; +extern crate rand; +extern crate tokio_io; +extern crate varint; + +mod read; +mod write; +mod shared; +mod header; use bytes::Bytes; -use circular_buffer::CircularBuffer; -use futures::prelude::*; -use libp2p_stream_muxer::StreamMuxer; +use futures::{Async, Future, Poll}; +use futures::future::{self, FutureResult}; +use header::{MultiplexEnd, MultiplexHeader}; +use swarm::muxing::StreamMuxer; +use swarm::ConnectionUpgrade; use parking_lot::Mutex; -use std::collections::HashMap; +use read::{read_stream, MultiplexReadState}; +use shared::{buf_from_slice, ByteBuf, MultiplexShared}; +use std::iter; use std::io::{self, Read, Write}; use std::sync::Arc; +use std::sync::atomic::{self, AtomicUsize}; use tokio_io::{AsyncRead, AsyncWrite}; +use write::write_stream; // So the multiplex is essentially a distributed finite state machine. // @@ -30,338 +64,55 @@ use tokio_io::{AsyncRead, AsyncWrite}; // In the second state, the substream ID is known. Only this substream can progress until the packet // is consumed. -/// Number of bits used for the metadata on multiplex packets -enum NextMultiplexState { - NewStream(usize), - ParsingMessageBody(usize), - Ignore, -} - -enum MultiplexReadState { - Header { state: varint::DecoderState }, - BodyLength { - state: varint::DecoderState, - next: NextMultiplexState, - }, - NewStream { - substream_id: usize, - name: bytes::BytesMut, - remaining_bytes: usize, - }, - ParsingMessageBody { - substream_id: usize, - remaining_bytes: usize, - }, - Ignore { remaining_bytes: usize }, -} - -impl Default for MultiplexReadState { - fn default() -> Self { - MultiplexReadState::Header { state: Default::default() } - } -} - -struct MultiplexWriteState { - buffer: CircularBuffer<[u8; 1024]>, -} - -// TODO: Add writing. We should also add some form of "pending packet" so that we can always open at -// least one new substream. If this is stored on the substream itself then we can open -// infinite new substreams. -// -// When we've implemented writing, we should send the close message on `Substream` drop. This -// should probably be implemented with some kind of "pending close message" queue. The -// priority should go: -// 1. Open new stream messages -// 2. Regular messages -// 3. Close messages -// Since if we receive a message to a closed stream we just drop it anyway. -struct MultiplexShared { - // We use `Option` in order to take ownership of heap allocations within `DecoderState` and - // `BytesMut`. If this is ever observably `None` then something has panicked or the underlying - // stream returned an error. - read_state: Option, - stream: T, - // true if the stream is open, false otherwise - open_streams: HashMap, - // TODO: Should we use a version of this with a fixed size that doesn't allocate and return - // `WouldBlock` if it's full? - to_open: HashMap, -} - pub struct Substream { - id: usize, + id: u32, + end: MultiplexEnd, name: Option, state: Arc>>, + buffer: Option>, } impl Drop for Substream { fn drop(&mut self) { let mut lock = self.state.lock(); - lock.open_streams.insert(self.id, false); + lock.close_stream(self.id); } } impl Substream { fn new>>( - id: usize, + id: u32, + end: MultiplexEnd, name: B, state: Arc>>, ) -> Self { let name = name.into(); - Substream { id, name, state } + Substream { + id, + end, + name, + state, + buffer: None, + } } pub fn name(&self) -> Option<&Bytes> { self.name.as_ref() } -} -/// This is unsafe because you must ensure that only the `AsyncRead` that was passed in is later -/// used to write to the returned buffer. -unsafe fn create_buffer_for(capacity: usize, inner: &R) -> bytes::BytesMut { - let mut buffer = bytes::BytesMut::with_capacity(capacity); - buffer.set_len(capacity); - inner.prepare_uninitialized_buffer(&mut buffer); - buffer -} - -fn read_stream<'a, O: Into>, T: AsyncRead>( - lock: &mut MultiplexShared, - stream_data: O, -) -> io::Result { - use num_traits::cast::ToPrimitive; - use MultiplexReadState::*; - - let mut stream_data = stream_data.into(); - let stream_has_been_gracefully_closed = stream_data - .as_ref() - .and_then(|&(id, _)| lock.open_streams.get(&id)) - .map(|is_open| !is_open) - .unwrap_or(false); - - let mut on_block: io::Result = if stream_has_been_gracefully_closed { - Ok(0) - } else { - Err(io::Error::from(io::ErrorKind::WouldBlock)) - }; - - loop { - match lock.read_state.take().expect("Logic error or panic") { - Header { state: varint_state } => { - match varint_state.read(&mut lock.stream).map_err(|_| { - io::Error::from(io::ErrorKind::Other) - })? { - Ok(header) => { - let MultiplexHeader { - substream_id, - packet_type, - } = MultiplexHeader::parse(header).map_err(|_| { - io::Error::from(io::ErrorKind::Other) - })?; - - match packet_type { - PacketType::Open => { - lock.read_state = Some(BodyLength { - state: Default::default(), - next: NextMultiplexState::NewStream(substream_id), - }) - } - PacketType::Message(_) => { - lock.read_state = Some(BodyLength { - state: Default::default(), - next: NextMultiplexState::ParsingMessageBody(substream_id), - }) - } - // NOTE: What's the difference between close and reset? - PacketType::Close(_) | - PacketType::Reset(_) => { - lock.read_state = Some(BodyLength { - state: Default::default(), - next: NextMultiplexState::Ignore, - }); - - lock.open_streams.remove(&substream_id); - } - } - } - Err(new_state) => { - lock.read_state = Some(Header { state: new_state }); - return on_block; - } - } - } - BodyLength { - state: varint_state, - next, - } => { - match varint_state.read(&mut lock.stream).map_err(|_| { - io::Error::from(io::ErrorKind::Other) - })? { - Ok(length) => { - use NextMultiplexState::*; - - let length = length.to_usize().ok_or( - io::Error::from(io::ErrorKind::Other), - )?; - - lock.read_state = Some(match next { - Ignore => MultiplexReadState::Ignore { remaining_bytes: length }, - NewStream(substream_id) => MultiplexReadState::NewStream { - // This is safe as long as we only use `lock.stream` to write to - // this field - name: unsafe { create_buffer_for(length, &lock.stream) }, - remaining_bytes: length, - substream_id, - }, - ParsingMessageBody(substream_id) => { - let is_open = lock.open_streams - .get(&substream_id) - .map(|is_open| *is_open) - .unwrap_or(false); - if is_open { - MultiplexReadState::ParsingMessageBody { - remaining_bytes: length, - substream_id, - } - } else { - MultiplexReadState::Ignore { remaining_bytes: length } - } - } - }); - } - Err(new_state) => { - lock.read_state = Some(BodyLength { - state: new_state, - next, - }); - - return on_block; - } - } - } - NewStream { - substream_id, - mut name, - remaining_bytes, - } => { - if remaining_bytes == 0 { - lock.to_open.insert(substream_id, name.freeze()); - - lock.read_state = Some(Default::default()); - } else { - let cursor_pos = name.len() - remaining_bytes; - let consumed = lock.stream.read(&mut name[cursor_pos..]); - - match consumed { - Ok(consumed) => { - let new_remaining = remaining_bytes - consumed; - - lock.read_state = Some(NewStream { - substream_id, - name, - remaining_bytes: new_remaining, - }) - } - Err(ref err) if err.kind() == io::ErrorKind::WouldBlock => { - lock.read_state = Some(NewStream { - substream_id, - name, - remaining_bytes, - }); - - return on_block; - } - Err(other) => return Err(other), - } - } - } - ParsingMessageBody { - substream_id, - remaining_bytes, - } => { - if let Some((ref mut id, ref mut buf)) = stream_data { - use MultiplexReadState::*; - - if substream_id == *id { - if remaining_bytes == 0 { - lock.read_state = Some(Default::default()); - } else { - let read_result = { - let new_len = buf.len().min(remaining_bytes); - let slice = &mut buf[..new_len]; - - lock.stream.read(slice) - }; - - match read_result { - Ok(consumed) => { - let new_remaining = remaining_bytes - consumed; - - lock.read_state = Some(ParsingMessageBody { - substream_id, - remaining_bytes: new_remaining, - }); - - on_block = Ok(on_block.unwrap_or(0) + consumed); - } - Err(ref err) if err.kind() == io::ErrorKind::WouldBlock => { - lock.read_state = Some(ParsingMessageBody { - substream_id, - remaining_bytes, - }); - - return on_block; - } - Err(other) => return Err(other), - } - } - } else { - lock.read_state = Some(ParsingMessageBody { - substream_id, - remaining_bytes, - }); - - // We cannot make progress here, another stream has to accept this packet - return on_block; - } - } - } - Ignore { mut remaining_bytes } => { - let mut ignore_buf: [u8; 256] = [0; 256]; - - loop { - if remaining_bytes == 0 { - lock.read_state = Some(Default::default()); - } else { - let new_len = ignore_buf.len().min(remaining_bytes); - match lock.stream.read(&mut ignore_buf[..new_len]) { - Ok(consumed) => remaining_bytes -= consumed, - Err(ref err) if err.kind() == io::ErrorKind::WouldBlock => { - lock.read_state = Some(Ignore { remaining_bytes }); - - return on_block; - } - Err(other) => return Err(other), - } - } - } - } - } + pub fn id(&self) -> u32 { + self.id } } -// TODO: We always zero the buffer, we should delegate to the inner stream. Maybe use a `RWLock` -// instead? +// TODO: We always zero the buffer, we should delegate to the inner stream. impl Read for Substream { - // TODO: Is it wasteful to have all of our substreams try to make progress? Can we use an - // `AtomicBool` or `AtomicUsize` to limit the substreams that try to progress? fn read(&mut self, buf: &mut [u8]) -> io::Result { let mut lock = match self.state.try_lock() { Some(lock) => lock, - None => return Err(io::Error::from(io::ErrorKind::WouldBlock)), + None => return Err(io::ErrorKind::WouldBlock.into()), }; read_stream(&mut lock, (self.id, buf)) @@ -372,91 +123,50 @@ impl AsyncRead for Substream {} impl Write for Substream { fn write(&mut self, buf: &[u8]) -> io::Result { - unimplemented!() + let mut lock = self.state.try_lock().ok_or(io::ErrorKind::WouldBlock)?; + + let mut buffer = self.buffer + .take() + .unwrap_or_else(|| io::Cursor::new(buf_from_slice(buf))); + + let out = write_stream( + &mut *lock, + write::WriteRequest::substream(MultiplexHeader::message(self.id, self.end)), + &mut buffer, + ); + + if buffer.position() < buffer.get_ref().len() as u64 { + self.buffer = Some(buffer); + } + + out } fn flush(&mut self) -> io::Result<()> { - unimplemented!() + self.state + .try_lock() + .ok_or(io::ErrorKind::WouldBlock)? + .stream + .flush() } } impl AsyncWrite for Substream { fn shutdown(&mut self) -> Poll<(), io::Error> { - unimplemented!() + Ok(Async::Ready(())) } } -struct ParseError; - -enum MultiplexEnd { - Initiator, - Receiver, -} - -struct MultiplexHeader { - pub packet_type: PacketType, - pub substream_id: usize, -} -enum PacketType { - Open, - Close(MultiplexEnd), - Reset(MultiplexEnd), - Message(MultiplexEnd), -} - -impl MultiplexHeader { - // TODO: Use `u128` or another large integer type instead of bigint since we never use more than - // `pointer width + FLAG_BITS` bits and unconditionally allocating 1-3 `u32`s for that is - // ridiculous (especially since even for small numbers we have to allocate 1 `u32`). - // If this is the future and `BigUint` is better-optimised (maybe by using `Bytes`) then - // forget it. - fn parse(header: num_bigint::BigUint) -> Result { - use num_traits::cast::ToPrimitive; - - const FLAG_BITS: usize = 3; - - // `&header` to make `>>` produce a new `BigUint` instead of consuming the old `BigUint` - let substream_id = ((&header) >> FLAG_BITS).to_usize().ok_or(ParseError)?; - - let flag_mask = (2usize << FLAG_BITS) - 1; - let flags = header.to_usize().ok_or(ParseError)? & flag_mask; - - // Yes, this is really how it works. No, I don't know why. - let packet_type = match flags { - 0 => PacketType::Open, - - 1 => PacketType::Message(MultiplexEnd::Receiver), - 2 => PacketType::Message(MultiplexEnd::Initiator), - - 3 => PacketType::Close(MultiplexEnd::Receiver), - 4 => PacketType::Close(MultiplexEnd::Initiator), - - 5 => PacketType::Reset(MultiplexEnd::Receiver), - 6 => PacketType::Reset(MultiplexEnd::Initiator), - - _ => return Err(ParseError), - }; - - Ok(MultiplexHeader { - substream_id, - packet_type, - }) - } -} - -pub struct Multiplex { +pub struct InboundFuture { + end: MultiplexEnd, state: Arc>>, } -pub struct InboundStream { - state: Arc>>, -} - -impl Stream for InboundStream { +impl Future for InboundFuture { type Item = Substream; type Error = io::Error; - fn poll(&mut self) -> Poll, Self::Error> { + fn poll(&mut self) -> Poll { let mut lock = match self.state.try_lock() { Some(lock) => lock, None => return Ok(Async::NotReady), @@ -464,8 +174,8 @@ impl Stream for InboundStream { // Attempt to make progress, but don't block if we can't match read_stream(&mut lock, None) { - Ok(_) => (), - Err(ref err) if err.kind() == io::ErrorKind::WouldBlock => (), + Ok(_) => {} + Err(ref err) if err.kind() == io::ErrorKind::WouldBlock => {} Err(err) => return Err(err), } @@ -479,30 +189,413 @@ impl Stream for InboundStream { "We just checked that this key exists and we have exclusive access to the map, QED", ); - Ok(Async::Ready( - Some(Substream::new(id, name, self.state.clone())), - )) + lock.open_stream(id); + + Ok(Async::Ready(Substream::new( + id, + self.end, + name, + Arc::clone(&self.state), + ))) + } +} + +pub struct OutboundFuture { + meta: Arc, + current_id: Option<(io::Cursor, u32)>, + state: Arc>>, +} + +impl OutboundFuture { + fn new(muxer: Multiplex) -> Self { + OutboundFuture { + current_id: None, + meta: muxer.meta, + state: muxer.state, + } + } +} + +fn nonce_to_id(id: usize, end: MultiplexEnd) -> u32 { + id as u32 * 2 + if end == MultiplexEnd::Initiator { 1 } else { 0 } +} + +impl Future for OutboundFuture { + type Item = Substream; + type Error = io::Error; + + fn poll(&mut self) -> Poll { + let mut lock = match self.state.try_lock() { + Some(lock) => lock, + None => return Ok(Async::NotReady), + }; + + loop { + let (mut id_str, id) = self.current_id.take().unwrap_or_else(|| { + let next = nonce_to_id( + self.meta.nonce.fetch_add(1, atomic::Ordering::Relaxed), + self.meta.end, + ); + ( + io::Cursor::new(buf_from_slice(format!("{}", next).as_bytes())), + next as u32, + ) + }); + + match write_stream( + &mut *lock, + write::WriteRequest::meta(MultiplexHeader::open(id)), + &mut id_str, + ) { + Ok(_) => { + debug_assert!(id_str.position() <= id_str.get_ref().len() as u64); + if id_str.position() == id_str.get_ref().len() as u64 { + if lock.open_stream(id) { + return Ok(Async::Ready(Substream::new( + id, + self.meta.end, + Bytes::from(&id_str.get_ref()[..]), + Arc::clone(&self.state), + ))); + } + } else { + self.current_id = Some((id_str, id)); + return Ok(Async::NotReady); + } + } + Err(ref err) if err.kind() == io::ErrorKind::WouldBlock => { + self.current_id = Some((id_str, id)); + + return Ok(Async::NotReady); + } + Err(other) => return Err(other), + } + } + } +} + +pub struct MultiplexMetadata { + nonce: AtomicUsize, + end: MultiplexEnd, +} + +pub struct Multiplex { + meta: Arc, + state: Arc>>, +} + +impl Clone for Multiplex { + fn clone(&self) -> Self { + Multiplex { + meta: self.meta.clone(), + state: self.state.clone(), + } + } +} + +impl Multiplex { + pub fn new(stream: T, end: MultiplexEnd) -> Self { + Multiplex { + meta: Arc::new(MultiplexMetadata { + nonce: AtomicUsize::new(0), + end, + }), + state: Arc::new(Mutex::new(MultiplexShared::new(stream))), + } + } + + pub fn dial(stream: T) -> Self { + Self::new(stream, MultiplexEnd::Initiator) + } + + pub fn listen(stream: T) -> Self { + Self::new(stream, MultiplexEnd::Receiver) } } impl StreamMuxer for Multiplex { type Substream = Substream; - type OutboundSubstreams = Box>; - type InboundSubstreams = InboundStream; + type OutboundSubstream = OutboundFuture; + type InboundSubstream = InboundFuture; - fn inbound(&mut self) -> Self::InboundSubstreams { - InboundStream { state: self.state.clone() } + fn inbound(self) -> Self::InboundSubstream { + InboundFuture { + state: Arc::clone(&self.state), + end: self.meta.end, + } } - fn outbound(&mut self) -> Self::OutboundSubstreams { - unimplemented!() + fn outbound(self) -> Self::OutboundSubstream { + OutboundFuture::new(self) + } +} + +pub struct MultiplexConfig; + +impl ConnectionUpgrade for MultiplexConfig +where + C: AsyncRead + AsyncWrite, +{ + type Output = Multiplex; + type Future = FutureResult, io::Error>; + type UpgradeIdentifier = (); + type NamesIter = iter::Once<(Bytes, ())>; + + #[inline] + fn upgrade(self, i: C, _: ()) -> Self::Future { + future::ok(Multiplex::dial(i)) + } + + #[inline] + fn protocol_names(&self) -> Self::NamesIter { + iter::once((Bytes::from("/mplex/6.7.0"), ())) } } #[cfg(test)] mod tests { + use super::*; + use std::io; + #[test] - fn it_works() { - assert_eq!(2 + 2, 4); + fn can_use_one_stream() { + let message = b"Hello, world!"; + + let stream = io::Cursor::new(Vec::new()); + + let mplex = Multiplex::dial(stream); + + let mut substream = mplex.clone().outbound().wait().unwrap(); + + assert!(substream.write(message).is_ok()); + + let id = substream.id(); + + assert_eq!( + substream + .name() + .and_then(|bytes| { String::from_utf8(bytes.to_vec()).ok() }), + Some(id.to_string()) + ); + + let stream = io::Cursor::new(mplex.state.lock().stream.get_ref().clone()); + + let mplex = Multiplex::listen(stream); + + let mut substream = mplex.inbound().wait().unwrap(); + + assert_eq!(id, substream.id()); + assert_eq!( + substream + .name() + .and_then(|bytes| { String::from_utf8(bytes.to_vec()).ok() }), + Some(id.to_string()) + ); + + let mut buf = vec![0; message.len()]; + + assert!(substream.read(&mut buf).is_ok()); + assert_eq!(&buf, message); + } + + #[test] + fn can_use_many_streams() { + let stream = io::Cursor::new(Vec::new()); + + let mplex = Multiplex::dial(stream); + + let mut outbound: Vec> = vec![ + mplex.clone().outbound().wait().unwrap(), + mplex.clone().outbound().wait().unwrap(), + mplex.clone().outbound().wait().unwrap(), + mplex.clone().outbound().wait().unwrap(), + mplex.clone().outbound().wait().unwrap(), + ]; + + outbound.sort_by_key(|a| a.id()); + + for (i, substream) in outbound.iter_mut().enumerate() { + assert!(substream.write(i.to_string().as_bytes()).is_ok()); + } + + let stream = io::Cursor::new(mplex.state.lock().stream.get_ref().clone()); + + let mplex = Multiplex::listen(stream); + + let mut inbound: Vec> = vec![ + mplex.clone().inbound().wait().unwrap(), + mplex.clone().inbound().wait().unwrap(), + mplex.clone().inbound().wait().unwrap(), + mplex.clone().inbound().wait().unwrap(), + mplex.clone().inbound().wait().unwrap(), + ]; + + inbound.sort_by_key(|a| a.id()); + + for (substream, outbound) in inbound.iter_mut().zip(outbound.iter()) { + let id = outbound.id(); + assert_eq!(id, substream.id()); + assert_eq!( + substream + .name() + .and_then(|bytes| { String::from_utf8(bytes.to_vec()).ok() }), + Some(id.to_string()) + ); + + let mut buf = [0; 3]; + assert_eq!(substream.read(&mut buf).unwrap(), 1); + } + } + + #[test] + fn packets_to_unopened_streams_are_dropped() { + use std::iter; + + let message = b"Hello, world!"; + + // We use a large dummy length to exercise ignoring data longer than `ignore_buffer.len()` + let dummy_length = 1000; + + let input = iter::empty() + // Open a stream + .chain(varint::encode(MultiplexHeader::open(0).to_u64())) + // 0-length body (stream has no name) + .chain(varint::encode(0usize)) + + // "Message"-type packet for an unopened stream + .chain( + varint::encode( + // ID for an unopened stream: 1 + MultiplexHeader::message(1, MultiplexEnd::Initiator).to_u64(), + ).into_iter(), + ) + // Body: `dummy_length` of zeroes + .chain(varint::encode(dummy_length)) + .chain(iter::repeat(0).take(dummy_length)) + + // "Message"-type packet for an opened stream + .chain( + varint::encode( + // ID for an opened stream: 0 + MultiplexHeader::message(0, MultiplexEnd::Initiator).to_u64(), + ).into_iter(), + ) + .chain(varint::encode(message.len())) + .chain(message.iter().cloned()) + + .collect::>(); + + let mplex = Multiplex::listen(io::Cursor::new(input)); + + let mut substream = mplex.inbound().wait().unwrap(); + + assert_eq!(substream.id(), 0); + assert_eq!(substream.name(), None); + + let mut buf = vec![0; message.len()]; + + assert!(substream.read(&mut buf).is_ok()); + assert_eq!(&buf, message); + } + + #[test] + fn can_close_streams() { + use std::iter; + + // Dummy data in the body of the close packet (since the de facto protocol is to accept but + // ignore this data) + let dummy_length = 64; + + let input = iter::empty() + // Open a stream + .chain(varint::encode(MultiplexHeader::open(0).to_u64())) + // 0-length body (stream has no name) + .chain(varint::encode(0usize)) + + // Immediately close the stream + .chain( + varint::encode( + // ID for an unopened stream: 1 + MultiplexHeader::close(0, MultiplexEnd::Initiator).to_u64(), + ).into_iter(), + ) + .chain(varint::encode(dummy_length)) + .chain(iter::repeat(0).take(dummy_length)) + + // Send packet to the closed stream + .chain( + varint::encode( + // ID for an opened stream: 0 + MultiplexHeader::message(0, MultiplexEnd::Initiator).to_u64(), + ).into_iter(), + ) + .chain(varint::encode(dummy_length)) + .chain(iter::repeat(0).take(dummy_length)) + + .collect::>(); + + let mplex = Multiplex::listen(io::Cursor::new(input)); + + let mut substream = mplex.inbound().wait().unwrap(); + + assert_eq!(substream.id(), 0); + assert_eq!(substream.name(), None); + + assert_eq!(substream.read(&mut [0; 100][..]).unwrap(), 0); + } + + #[test] + fn real_world_data() { + let data: Vec = vec![ + // Open stream 1 + 8, + 0, + + // Message for stream 1 (length 20) + 10, + 20, + 19, + 47, + 109, + 117, + 108, + 116, + 105, + 115, + 116, + 114, + 101, + 97, + 109, + 47, + 49, + 46, + 48, + 46, + 48, + 10, + ]; + + let mplex = Multiplex::listen(io::Cursor::new(data)); + + let mut substream = mplex.inbound().wait().unwrap(); + + assert_eq!(substream.id(), 1); + + assert_eq!(substream.name(), None); + + let mut out = vec![]; + + for _ in 0..20 { + let mut buf = [0; 1]; + + assert_eq!(substream.read(&mut buf[..]).unwrap(), 1); + + out.push(buf[0]); + } + + assert_eq!(out[0], 19); + assert_eq!(&out[1..0x14 - 1], b"/multistream/1.0.0"); + assert_eq!(out[0x14 - 1], 0x0a); } } diff --git a/multiplex-rs/src/read.rs b/multiplex-rs/src/read.rs new file mode 100644 index 00000000..5b67649b --- /dev/null +++ b/multiplex-rs/src/read.rs @@ -0,0 +1,402 @@ +// Copyright 2017 Parity Technologies (UK) Ltd. +// +// Permission is hereby granted, free of charge, to any person obtaining a +// copy of this software and associated documentation files (the "Software"), +// to deal in the Software without restriction, including without limitation +// the rights to use, copy, modify, merge, publish, distribute, sublicense, +// and/or sell copies of the Software, and to permit persons to whom the +// Software is furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS +// OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +// FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +// DEALINGS IN THE SOFTWARE. + +use {bytes, varint}; +use futures::Async; +use futures::task; +use header::{MultiplexHeader, PacketType}; +use std::io; +use tokio_io::AsyncRead; +use shared::SubstreamMetadata; + +pub enum NextMultiplexState { + NewStream(u32), + ParsingMessageBody(u32), + Ignore, +} + +pub enum MultiplexReadState { + Header { + state: varint::DecoderState, + }, + BodyLength { + state: varint::DecoderState, + next: NextMultiplexState, + }, + NewStream { + substream_id: u32, + name: bytes::BytesMut, + remaining_bytes: usize, + }, + ParsingMessageBody { + substream_id: u32, + remaining_bytes: usize, + }, + Ignore { + remaining_bytes: usize, + }, +} + +impl Default for MultiplexReadState { + fn default() -> Self { + MultiplexReadState::Header { + state: Default::default(), + } + } +} + +fn create_buffer(capacity: usize) -> bytes::BytesMut { + let mut buffer = bytes::BytesMut::with_capacity(capacity); + let zeroes = [0; 1024]; + let mut cap = capacity; + + while cap > 0 { + let len = cap.min(zeroes.len()); + buffer.extend_from_slice(&zeroes[..len]); + cap -= len; + } + + buffer +} + +pub fn read_stream<'a, O: Into>, T: AsyncRead>( + lock: &mut ::shared::MultiplexShared, + stream_data: O, +) -> io::Result { + use self::MultiplexReadState::*; + + let mut stream_data = stream_data.into(); + let stream_has_been_gracefully_closed = stream_data + .as_ref() + .and_then(|&(id, _)| lock.open_streams.get(&id)) + .map(|meta| !meta.open()) + .unwrap_or(false); + + let mut on_block: io::Result = if stream_has_been_gracefully_closed { + Ok(0) + } else { + Err(io::ErrorKind::WouldBlock.into()) + }; + + loop { + match lock.read_state.take().unwrap_or_default() { + Header { + state: mut varint_state, + } => { + match varint_state.read(&mut lock.stream) { + Ok(Async::Ready(header)) => { + let header = if let Some(header) = header { + header + } else { + return Ok(0); + }; + + let MultiplexHeader { + substream_id, + packet_type, + } = MultiplexHeader::parse(header).map_err(|err| { + io::Error::new( + io::ErrorKind::Other, + format!("Error parsing header: {:?}", err), + ) + })?; + + match packet_type { + PacketType::Open => { + lock.read_state = Some(BodyLength { + state: Default::default(), + next: NextMultiplexState::NewStream(substream_id), + }) + } + PacketType::Message(_) => { + lock.read_state = Some(BodyLength { + state: Default::default(), + next: NextMultiplexState::ParsingMessageBody(substream_id), + }) + } + // NOTE: What's the difference between close and reset? + PacketType::Close(_) | PacketType::Reset(_) => { + lock.read_state = Some(BodyLength { + state: Default::default(), + next: NextMultiplexState::Ignore, + }); + + lock.close_stream(substream_id); + } + } + } + Ok(Async::NotReady) => { + lock.read_state = Some(Header { + state: varint_state, + }); + return on_block; + } + Err(error) => { + return if let varint::Error(varint::ErrorKind::Io(inner), ..) = error { + Err(inner) + } else { + Err(io::Error::new(io::ErrorKind::Other, error.description())) + }; + } + } + } + BodyLength { + state: mut varint_state, + next, + } => { + use self::NextMultiplexState::*; + + match varint_state + .read(&mut lock.stream) + .map_err(|_| io::Error::new(io::ErrorKind::Other, "Error reading varint"))? + { + Async::Ready(length) => { + // TODO: Limit `length` to prevent resource-exhaustion DOS + let length = if let Some(length) = length { + length + } else { + return Ok(0); + }; + + lock.read_state = match next { + Ignore => Some(MultiplexReadState::Ignore { + remaining_bytes: length, + }), + NewStream(substream_id) => { + if length == 0 { + lock.to_open.insert(substream_id, None); + + None + } else { + Some(MultiplexReadState::NewStream { + // TODO: Uninit buffer + name: create_buffer(length), + remaining_bytes: length, + substream_id, + }) + } + } + ParsingMessageBody(substream_id) => { + let is_open = lock.open_streams + .get(&substream_id) + .map(SubstreamMetadata::open) + .unwrap_or_else(|| lock.to_open.contains_key(&substream_id)); + + if is_open { + Some(MultiplexReadState::ParsingMessageBody { + remaining_bytes: length, + substream_id, + }) + } else { + Some(MultiplexReadState::Ignore { + remaining_bytes: length, + }) + } + } + }; + } + Async::NotReady => { + lock.read_state = Some(BodyLength { + state: varint_state, + next, + }); + + return on_block; + } + } + } + NewStream { + substream_id, + mut name, + remaining_bytes, + } => { + if remaining_bytes == 0 { + lock.to_open.insert(substream_id, Some(name.freeze())); + + lock.read_state = None; + } else { + let cursor_pos = name.len() - remaining_bytes; + let consumed = lock.stream.read(&mut name[cursor_pos..]); + + match consumed { + Ok(consumed) => { + let new_remaining = remaining_bytes - consumed; + + lock.read_state = Some(NewStream { + substream_id, + name, + remaining_bytes: new_remaining, + }); + } + Err(ref err) if err.kind() == io::ErrorKind::WouldBlock => { + lock.read_state = Some(NewStream { + substream_id, + name, + remaining_bytes, + }); + + return on_block; + } + Err(other) => { + lock.read_state = Some(NewStream { + substream_id, + name, + remaining_bytes, + }); + + return Err(other); + } + } + } + } + ParsingMessageBody { + substream_id, + remaining_bytes, + } => { + if let Some((ref mut id, ref mut buf)) = stream_data { + use MultiplexReadState::*; + + if remaining_bytes == 0 { + lock.read_state = None; + + return on_block; + } else if substream_id == *id { + let number_read = *on_block.as_ref().unwrap_or(&0); + + if buf.len() == 0 { + return Ok(0); + } else if number_read >= buf.len() { + lock.read_state = Some(ParsingMessageBody { + substream_id, + remaining_bytes, + }); + + return on_block; + } + + let read_result = { + // We know this won't panic because of the earlier + // `number_read >= buf.len()` check + let new_len = (buf.len() - number_read).min(remaining_bytes); + let slice = &mut buf[number_read..number_read + new_len]; + + lock.stream.read(slice) + }; + + match read_result { + Ok(consumed) => { + let new_remaining = remaining_bytes - consumed; + + lock.read_state = Some(ParsingMessageBody { + substream_id, + remaining_bytes: new_remaining, + }); + + on_block = Ok(number_read + consumed); + } + Err(ref err) if err.kind() == io::ErrorKind::WouldBlock => { + lock.read_state = Some(ParsingMessageBody { + substream_id, + remaining_bytes, + }); + + return on_block; + } + Err(other) => { + lock.read_state = Some(ParsingMessageBody { + substream_id, + remaining_bytes, + }); + + return Err(other); + } + } + } else { + lock.read_state = Some(ParsingMessageBody { + substream_id, + remaining_bytes, + }); + + if let Some(task) = lock.open_streams + .get(&substream_id) + .and_then(SubstreamMetadata::read_task) + { + task.notify(); + } + + let write = lock.open_streams + .get(id) + .and_then(SubstreamMetadata::write_task) + .cloned(); + lock.open_streams.insert( + *id, + SubstreamMetadata::Open { + read: Some(task::current()), + write, + }, + ); + // We cannot make progress here, another stream has to accept this packet + return on_block; + } + } else { + lock.read_state = Some(ParsingMessageBody { + substream_id, + remaining_bytes, + }); + + // We cannot make progress here, a stream has to accept this packet + return on_block; + } + } + Ignore { + mut remaining_bytes, + } => { + let mut ignore_buf: [u8; 256] = [0; 256]; + + loop { + if remaining_bytes == 0 { + lock.read_state = None; + break; + } else { + let new_len = ignore_buf.len().min(remaining_bytes); + match lock.stream.read(&mut ignore_buf[..new_len]) { + Ok(consumed) => { + remaining_bytes -= consumed; + lock.read_state = Some(Ignore { + remaining_bytes: remaining_bytes, + }); + } + Err(ref err) if err.kind() == io::ErrorKind::WouldBlock => { + lock.read_state = Some(Ignore { remaining_bytes }); + + return on_block; + } + Err(other) => { + lock.read_state = Some(Ignore { remaining_bytes }); + + return Err(other); + } + } + } + } + } + } + } +} diff --git a/multiplex-rs/src/shared.rs b/multiplex-rs/src/shared.rs new file mode 100644 index 00000000..172bfd34 --- /dev/null +++ b/multiplex-rs/src/shared.rs @@ -0,0 +1,108 @@ +// Copyright 2017 Parity Technologies (UK) Ltd. +// +// Permission is hereby granted, free of charge, to any person obtaining a +// copy of this software and associated documentation files (the "Software"), +// to deal in the Software without restriction, including without limitation +// the rights to use, copy, modify, merge, publish, distribute, sublicense, +// and/or sell copies of the Software, and to permit persons to whom the +// Software is furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS +// OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +// FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +// DEALINGS IN THE SOFTWARE. + +use read::MultiplexReadState; +use write::MultiplexWriteState; + +use std::collections::HashMap; +use bytes::Bytes; +use arrayvec::ArrayVec; +use futures::task::Task; + +const BUF_SIZE: usize = 1024; + +pub type ByteBuf = ArrayVec<[u8; BUF_SIZE]>; + +pub enum SubstreamMetadata { + Closed, + Open { + read: Option, + write: Option, + }, +} + +impl SubstreamMetadata { + pub fn open(&self) -> bool { + match *self { + SubstreamMetadata::Closed => false, + SubstreamMetadata::Open { .. } => true, + } + } + + pub fn read_task(&self) -> Option<&Task> { + match *self { + SubstreamMetadata::Closed => None, + SubstreamMetadata::Open { ref read, .. } => read.as_ref(), + } + } + + pub fn write_task(&self) -> Option<&Task> { + match *self { + SubstreamMetadata::Closed => None, + SubstreamMetadata::Open { ref write, .. } => write.as_ref(), + } + } +} + +// TODO: Split reading and writing into different structs and have information shared between the +// two in a `RwLock`, since `open_streams` and `to_open` are mostly read-only. +pub struct MultiplexShared { + // We use `Option` in order to take ownership of heap allocations within `DecoderState` and + // `BytesMut`. If this is ever observably `None` then something has panicked or the underlying + // stream returned an error. + pub read_state: Option, + pub write_state: Option, + pub stream: T, + // true if the stream is open, false otherwise + pub open_streams: HashMap, + // TODO: Should we use a version of this with a fixed size that doesn't allocate and return + // `WouldBlock` if it's full? + pub to_open: HashMap>, +} + +impl MultiplexShared { + pub fn new(stream: T) -> Self { + MultiplexShared { + read_state: Default::default(), + write_state: Default::default(), + open_streams: Default::default(), + to_open: Default::default(), + stream: stream, + } + } + + pub fn open_stream(&mut self, id: u32) -> bool { + self.open_streams + .entry(id) + .or_insert(SubstreamMetadata::Open { + read: None, + write: None, + }) + .open() + } + + pub fn close_stream(&mut self, id: u32) { + self.open_streams.insert(id, SubstreamMetadata::Closed); + } +} + +pub fn buf_from_slice(slice: &[u8]) -> ByteBuf { + slice.iter().cloned().take(BUF_SIZE).collect() +} diff --git a/multiplex-rs/src/write.rs b/multiplex-rs/src/write.rs new file mode 100644 index 00000000..3b702a67 --- /dev/null +++ b/multiplex-rs/src/write.rs @@ -0,0 +1,188 @@ +// Copyright 2017 Parity Technologies (UK) Ltd. +// +// Permission is hereby granted, free of charge, to any person obtaining a +// copy of this software and associated documentation files (the "Software"), +// to deal in the Software without restriction, including without limitation +// the rights to use, copy, modify, merge, publish, distribute, sublicense, +// and/or sell copies of the Software, and to permit persons to whom the +// Software is furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS +// OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +// FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +// DEALINGS IN THE SOFTWARE. + +use shared::{ByteBuf, MultiplexShared, SubstreamMetadata}; +use header::MultiplexHeader; + +use varint; +use futures::task; +use std::io; +use tokio_io::AsyncWrite; + +#[derive(Copy, Clone, PartialEq, Eq, Debug)] +pub enum RequestType { + Meta, + Substream, +} + +#[derive(Copy, Clone, PartialEq, Eq, Debug)] +pub struct WriteRequest { + header: MultiplexHeader, + request_type: RequestType, +} + +impl WriteRequest { + pub fn substream(header: MultiplexHeader) -> Self { + WriteRequest { + header, + request_type: RequestType::Substream, + } + } + + pub fn meta(header: MultiplexHeader) -> Self { + WriteRequest { + header, + request_type: RequestType::Meta, + } + } +} + +#[derive(Default, Debug)] +pub struct MultiplexWriteState { + current: Option<(WriteRequest, MultiplexWriteStateInner)>, + queued: Option, + // TODO: Actually close these + to_close: Vec, +} + +#[derive(Debug)] +pub enum MultiplexWriteStateInner { + WriteHeader { state: varint::EncoderState }, + BodyLength { state: varint::EncoderState }, + Body { size: usize }, +} + +pub fn write_stream( + lock: &mut MultiplexShared, + write_request: WriteRequest, + buf: &mut io::Cursor, +) -> io::Result { + use futures::Async; + use num_traits::cast::ToPrimitive; + use varint::WriteState; + use write::MultiplexWriteStateInner::*; + + let mut on_block = Err(io::ErrorKind::WouldBlock.into()); + let mut write_state = lock.write_state.take().unwrap_or_default(); + let (request, mut state) = write_state.current.take().unwrap_or_else(|| { + ( + write_request, + MultiplexWriteStateInner::WriteHeader { + state: varint::EncoderState::new(write_request.header.to_u64()), + }, + ) + }); + + let id = write_request.header.substream_id; + + match (request.request_type, write_request.request_type) { + (RequestType::Substream, RequestType::Substream) if request.header.substream_id != id => { + let read = lock.open_streams + .get(&id) + .and_then(SubstreamMetadata::read_task) + .cloned(); + + if let Some(task) = lock.open_streams + .get(&request.header.substream_id) + .and_then(SubstreamMetadata::write_task) + { + task.notify(); + } + + lock.open_streams.insert( + id, + SubstreamMetadata::Open { + write: Some(task::current()), + read, + }, + ); + lock.write_state = Some(write_state); + return on_block; + } + (RequestType::Substream, RequestType::Meta) + | (RequestType::Meta, RequestType::Substream) => { + lock.write_state = Some(write_state); + return on_block; + } + _ => {} + } + + loop { + // Err = should return, Ok = continue + let new_state = match state { + WriteHeader { + state: mut inner_state, + } => match inner_state + .write(&mut lock.stream) + .map_err(|_| io::ErrorKind::Other)? + { + Async::Ready(WriteState::Done(_)) => Ok(BodyLength { + state: varint::EncoderState::new(buf.get_ref().len()), + }), + Async::Ready(WriteState::Pending(_)) | Async::NotReady => { + Err(Some(WriteHeader { state: inner_state })) + } + }, + BodyLength { + state: mut inner_state, + } => match inner_state + .write(&mut lock.stream) + .map_err(|_| io::ErrorKind::Other)? + { + Async::Ready(WriteState::Done(_)) => Ok(Body { + size: inner_state.source().to_usize().unwrap_or(::std::usize::MAX), + }), + Async::Ready(WriteState::Pending(_)) => Ok(BodyLength { state: inner_state }), + Async::NotReady => Err(Some(BodyLength { state: inner_state })), + }, + Body { size } => { + if buf.position() == buf.get_ref().len() as u64 { + Err(None) + } else { + match lock.stream.write(&buf.get_ref()[buf.position() as usize..]) { + Ok(just_written) => { + let cur_pos = buf.position(); + buf.set_position(cur_pos + just_written as u64); + on_block = Ok(on_block.unwrap_or(0) + just_written); + Ok(Body { + size: size - just_written, + }) + } + Err(ref err) if err.kind() == io::ErrorKind::WouldBlock => { + Err(Some(Body { size })) + } + Err(other) => { + return Err(other); + } + } + } + } + }; + + match new_state { + Ok(new_state) => state = new_state, + Err(new_state) => { + write_state.current = new_state.map(|state| (request, state)); + lock.write_state = Some(write_state); + return on_block; + } + } + } +} diff --git a/varint-rs/Cargo.toml b/varint-rs/Cargo.toml index ffe819b0..fb2fca18 100644 --- a/varint-rs/Cargo.toml +++ b/varint-rs/Cargo.toml @@ -5,6 +5,8 @@ authors = ["Parity Technologies "] [dependencies] num-bigint = "0.1.40" +num-traits = "0.1.40" bytes = "0.4.5" tokio-io = "0.1" futures = "0.1" +error-chain = "0.11.0" diff --git a/varint-rs/src/lib.rs b/varint-rs/src/lib.rs index f126536b..9b9c521d 100644 --- a/varint-rs/src/lib.rs +++ b/varint-rs/src/lib.rs @@ -1,134 +1,404 @@ +#![warn(missing_docs)] + +//! Encoding and decoding state machines for protobuf varints + +// TODO: Non-allocating `BigUint`? extern crate num_bigint; +extern crate num_traits; extern crate tokio_io; extern crate bytes; extern crate futures; +#[macro_use] +extern crate error_chain; use bytes::BytesMut; +use futures::{Poll, Async}; use num_bigint::BigUint; -use tokio_io::AsyncRead; +use num_traits::ToPrimitive; +use tokio_io::{AsyncRead, AsyncWrite}; use tokio_io::codec::Decoder; use std::io; use std::io::prelude::*; -// TODO: error-chain -pub struct ParseError; +mod errors { + error_chain! { + errors { + ParseError { + description("error parsing varint") + display("error parsing varint") + } + WriteError { + description("error writing varint") + display("error writing varint") + } + } -#[derive(Default)] -pub struct DecoderState { - // TODO: Non-allocating `BigUint`? - accumulator: BigUint, - shift: usize, -} - -impl DecoderState { - pub fn new() -> Self { - Default::default() - } - - fn decode_one(mut self, byte: u8) -> Result { - self.accumulator = self.accumulator | (BigUint::from(byte & 0x7F) << self.shift); - self.shift += 7; - - if byte & 0x80 == 0 { - Ok(self.accumulator) - } else { - Err(self) + foreign_links { + Io(::std::io::Error); } } +} - // Why the weird type signature? Well, `BigUint` owns its storage, and we don't want to clone - // it. So, we want the accumulator to be moved out when it is ready. We could have also used - // `Option`, but this means that it's not possible to end up in an inconsistent state - // (`shift != 0 && accumulator.is_none()`). - pub fn read(self, mut input: R) -> Result, ParseError> { - let mut state = self; - loop { - // We read one at a time to prevent consuming too much of the buffer. - let mut buffer: [u8; 1] = [0]; +pub use errors::{Error, ErrorKind}; - match input.read_exact(&mut buffer) { - Ok(_) => { - state = match state.decode_one(buffer[0]) { - Ok(out) => break Ok(Ok(out)), - Err(state) => state, - }; - } - Err(ref err) if err.kind() == io::ErrorKind::WouldBlock => break Ok(Err(state)), - Err(_) => break Err(ParseError), +const USABLE_BITS_PER_BYTE: usize = 7; + +/// The state struct for the varint-to-bytes FSM +#[derive(Debug)] +pub struct EncoderState { + source: T, + // A "chunk" is a section of the contained `BigUint` `USABLE_BITS_PER_BYTE` bits long + num_chunks: usize, + cur_chunk: usize, +} + +/// Whether or not the varint writing was completed +#[derive(Debug, Eq, PartialEq, Copy, Clone)] +pub enum WriteState { + /// The encoder has finished writing + Done(usize), + /// The encoder still must write more bytes + Pending(usize), +} + +fn ceil_div(a: usize, b: usize) -> usize { + (a + b - 1) / b +} + +/// A trait to get the minimum number of bits required to represent a number +pub trait Bits { + /// The minimum number of bits required to represent `self` + fn bits(&self) -> usize; +} + +impl Bits for BigUint { + fn bits(&self) -> usize { + BigUint::bits(self) + } +} + +macro_rules! impl_bits { + ($t:ty) => { + impl Bits for $t { + fn bits(&self) -> usize { + (std::mem::size_of::<$t>() * 8) - self.leading_zeros() as usize } } } } -#[derive(Default)] -pub struct VarintDecoder { - state: Option, +impl_bits!(usize); +impl_bits!(u64); +impl_bits!(u32); +impl_bits!(u16); +impl_bits!(u8); + +/// Helper trait to allow multiple integer types to be encoded +pub trait EncoderHelper: Sized { + /// Write as much as possible of the inner integer to the output `AsyncWrite` + fn write(encoder: &mut EncoderState, output: W) + -> Poll; } -impl VarintDecoder { +/// Helper trait to allow multiple integer types to be encoded +pub trait DecoderHelper: Sized { + /// Decode a single byte + fn decode_one(decoder: &mut DecoderState, byte: u8) -> Option; + + /// Read as much of the varint as possible + fn read(decoder: &mut DecoderState, input: R) -> Poll, Error>; +} + +macro_rules! impl_decoderstate { + ($t:ty) => { impl_decoderstate!($t, <$t>::from, |v| v); }; + ($t:ty, $make_fn:expr) => { impl_decoderstate!($t, $make_fn, $make_fn); }; + ($t:ty, $make_fn:expr, $make_shift_fn:expr) => { + impl DecoderHelper for $t { + #[inline] + fn decode_one(decoder: &mut DecoderState, byte: u8) -> Option<$t> { + decoder.accumulator.take().and_then(|accumulator| { + let out = accumulator | + ( + $make_fn(byte & 0x7F) << + $make_shift_fn(decoder.shift * USABLE_BITS_PER_BYTE) + ); + decoder.shift += 1; + + if byte & 0x80 == 0 { + Some(out) + } else { + decoder.accumulator = AccumulatorState::InProgress(out); + None + } + }) + } + + fn read( + decoder: &mut DecoderState, + mut input: R + ) -> Poll, Error> { + if decoder.accumulator == AccumulatorState::Finished { + return Err(Error::with_chain( + io::Error::new( + io::ErrorKind::Other, + "Attempted to parse a second varint (create a new instance!)", + ), + ErrorKind::ParseError, + )); + } + + loop { + // We read one at a time to prevent consuming too much of the buffer. + let mut buffer: [u8; 1] = [0]; + + match input.read_exact(&mut buffer) { + Ok(()) => { + if let Some(out) = Self::decode_one(decoder, buffer[0]) { + break Ok(Async::Ready(Some(out))); + } + } + Err(ref err) if err.kind() == io::ErrorKind::WouldBlock => { + break Ok(Async::NotReady); + } + Err(inner) => if decoder.accumulator == AccumulatorState::NotStarted { + break Ok(Async::Ready(None)); + } else { + break Err(Error::with_chain(inner, ErrorKind::ParseError)) + }, + } + } + } + } + } +} + +macro_rules! impl_encoderstate { + ($t:ty) => { impl_encoderstate!($t, <$t>::from); }; + ($t:ty, $make_fn:expr) => { + impl EncoderHelper for $t { + /// Write as much as possible of the inner integer to the output `AsyncWrite` + fn write( + encoder: &mut EncoderState, + mut output: W, + ) -> Poll { + fn encode_one(encoder: &EncoderState<$t>) -> Option { + let last_chunk = encoder.num_chunks - 1; + + if encoder.cur_chunk > last_chunk { + return None; + } + + let masked = (&encoder.source >> (encoder.cur_chunk * USABLE_BITS_PER_BYTE)) & + $make_fn((1 << USABLE_BITS_PER_BYTE) - 1usize); + let masked = masked.to_u8().expect( + "Masked with 0b0111_1111, is less than u8::MAX, QED", + ); + + if encoder.cur_chunk == last_chunk { + Some(masked) + } else { + Some(masked | (1 << USABLE_BITS_PER_BYTE)) + } + } + + let mut written = 0usize; + + loop { + if let Some(byte) = encode_one(&encoder) { + let buffer: [u8; 1] = [byte]; + + match output.write_all(&buffer) { + Ok(()) => { + written += 1; + encoder.cur_chunk += 1; + } + Err(ref err) if err.kind() == io::ErrorKind::WouldBlock => { + break if written == 0 { + Ok(Async::NotReady) + } else { + Ok(Async::Ready(WriteState::Pending(written))) + }; + } + Err(inner) => break Err( + Error::with_chain(inner, ErrorKind::WriteError) + ), + } + } else { + break Ok(Async::Ready(WriteState::Done(written))); + } + } + } + } + } +} + +impl_encoderstate!(usize); +impl_encoderstate!(BigUint); +impl_encoderstate!(u64, (|val| val as u64)); +impl_encoderstate!(u32, (|val| val as u32)); + +impl_decoderstate!(usize); +impl_decoderstate!(BigUint); +impl_decoderstate!(u64, (|val| val as u64)); +impl_decoderstate!(u32, (|val| val as u32)); + +impl EncoderState { + pub fn source(&self) -> &T { + &self.source + } +} + +impl EncoderState { + /// Create a new encoder + pub fn new(inner: T) -> Self { + let bits = inner.bits(); + EncoderState { + source: inner, + num_chunks: ceil_div(bits, USABLE_BITS_PER_BYTE).max(1), + cur_chunk: 0, + } + } +} + +impl EncoderState { + /// Write as much as possible of the inner integer to the output `AsyncWrite` + pub fn write(&mut self, output: W) -> Poll { + T::write(self, output) + } +} + +#[derive(Debug, PartialEq, Eq)] +enum AccumulatorState { + InProgress(T), + NotStarted, + Finished, +} + +impl AccumulatorState { + fn take(&mut self) -> Option { + use std::mem; + use AccumulatorState::*; + + match mem::replace(self, AccumulatorState::Finished) { + InProgress(inner) => Some(inner), + NotStarted => Some(Default::default()), + Finished => None, + } + } +} + +/// The state struct for the varint bytes-to-bigint FSM +#[derive(Debug)] +pub struct DecoderState { + accumulator: AccumulatorState, + shift: usize, +} + +impl Default for DecoderState { + fn default() -> Self { + DecoderState { + accumulator: AccumulatorState::NotStarted, + shift: 0, + } + } +} + +impl DecoderState { + /// Make a new decoder pub fn new() -> Self { Default::default() } } -impl Decoder for VarintDecoder { - type Item = BigUint; +impl DecoderState { + /// Make a new decoder + pub fn read(&mut self, input: R) -> Poll, Error> { + T::read(self, input) + } +} + +/// Wrapper around `DecoderState` to make a `tokio` `Decoder` +#[derive(Debug)] +pub struct VarintDecoder { + state: Option>, +} + +impl Default for VarintDecoder { + fn default() -> Self { + VarintDecoder { state: None } + } +} + +impl VarintDecoder { + /// Make a new decoder + pub fn new() -> Self { + Default::default() + } +} + +impl Decoder for VarintDecoder { + type Item = T; type Error = io::Error; fn decode(&mut self, src: &mut BytesMut) -> Result, Self::Error> { loop { - if src.len() == 0 { + if src.is_empty() && self.state.is_some() { break Err(io::Error::from(io::ErrorKind::UnexpectedEof)); } else { // We know that the length is not 0, so this cannot fail. let first_byte = src.split_to(1)[0]; - let new_state = match self.state.take() { - Some(state) => state.decode_one(first_byte), - None => DecoderState::new().decode_one(first_byte), - }; + let mut state = self.state.take().unwrap_or_default(); + let out = T::decode_one(&mut state, first_byte); - match new_state { - Ok(out) => break Ok(Some(out)), - Err(state) => self.state = Some(state), + if let Some(out) = out { + break Ok(Some(out)); + } else { + self.state = Some(state); } } } } } -pub fn decode(stream: R) -> io::Result { - let mut out = BigUint::from(0u8); - let mut shift = 0; - let mut finished_cleanly = false; +/// Syncronously decode a number from a `Read` +pub fn decode(mut input: R) -> errors::Result { + let mut decoder = DecoderState::default(); - for i in stream.bytes() { - let i = i?; + loop { + // We read one at a time to prevent consuming too much of the buffer. + let mut buffer: [u8; 1] = [0]; - out = out | (BigUint::from(i & 0x7F) << shift); - shift += 7; - - if i & 0x80 == 0 { - finished_cleanly = true; - break; + match input.read_exact(&mut buffer) { + Ok(()) => { + if let Some(out) = T::decode_one(&mut decoder, buffer[0]) { + break Ok(out); + } + } + Err(inner) => break Err(Error::with_chain(inner, ErrorKind::ParseError)), } } +} - if finished_cleanly { - Ok(out) - } else { - Err(io::Error::from(io::ErrorKind::UnexpectedEof)) +/// Syncronously decode a number from a `Read` +pub fn encode(input: T) -> Vec { + let mut encoder = EncoderState::new(input); + let mut out = io::Cursor::new(Vec::with_capacity(1)); + + match T::write(&mut encoder, &mut out).expect("Writing to a vec should never fail, Q.E.D") { + Async::Ready(_) => out.into_inner(), + Async::NotReady => unreachable!(), } } #[cfg(test)] mod tests { - use super::{decode, VarintDecoder}; + use super::{decode, VarintDecoder, EncoderState}; use tokio_io::codec::FramedRead; use num_bigint::BigUint; use futures::{Future, Stream}; #[test] - fn can_decode_basic_uint() { + fn can_decode_basic_biguint() { assert_eq!( BigUint::from(300u16), decode(&[0b10101100, 0b00000010][..]).unwrap() @@ -136,7 +406,7 @@ mod tests { } #[test] - fn can_decode_basic_uint_async() { + fn can_decode_basic_biguint_async() { let result = FramedRead::new(&[0b10101100, 0b00000010][..], VarintDecoder::new()) .into_future() .map(|(out, _)| out) @@ -146,11 +416,77 @@ mod tests { assert_eq!(result.unwrap(), Some(BigUint::from(300u16))); } + #[test] + fn can_decode_trivial_usize_async() { + let result = FramedRead::new(&[1][..], VarintDecoder::new()) + .into_future() + .map(|(out, _)| out) + .map_err(|(out, _)| out) + .wait(); + + assert_eq!(result.unwrap(), Some(1usize)); + } + + #[test] + fn can_decode_basic_usize_async() { + let result = FramedRead::new(&[0b10101100, 0b00000010][..], VarintDecoder::new()) + .into_future() + .map(|(out, _)| out) + .map_err(|(out, _)| out) + .wait(); + + assert_eq!(result.unwrap(), Some(300usize)); + } + + #[test] + fn can_encode_basic_biguint_async() { + use std::io::Cursor; + use futures::Async; + use super::WriteState; + + let mut out = vec![0u8; 2]; + + { + let writable: Cursor<&mut [_]> = Cursor::new(&mut out); + + let mut state = EncoderState::new(BigUint::from(300usize)); + + assert_eq!( + state.write(writable).unwrap(), + Async::Ready(WriteState::Done(2)) + ); + } + + assert_eq!(out, vec![0b10101100, 0b00000010]); + } + + #[test] + fn can_encode_basic_usize_async() { + use std::io::Cursor; + use futures::Async; + use super::WriteState; + + let mut out = vec![0u8; 2]; + + { + let writable: Cursor<&mut [_]> = Cursor::new(&mut out); + + let mut state = EncoderState::new(300usize); + + assert_eq!( + state.write(writable).unwrap(), + Async::Ready(WriteState::Done(2)) + ); + } + + assert_eq!(out, vec![0b10101100, 0b00000010]); + } + #[test] fn unexpected_eof_async() { use std::io; - let result = FramedRead::new(&[0b10101100, 0b10000010][..], VarintDecoder::new()) + let result = FramedRead::new(&[0b10101100, 0b10000010][..], VarintDecoder::::new()) .into_future() .map(|(out, _)| out) .map_err(|(out, _)| out)