diff --git a/core/tests/multiplex.rs b/core/tests/multiplex.rs index 7f03abcb..dbb4f123 100644 --- a/core/tests/multiplex.rs +++ b/core/tests/multiplex.rs @@ -78,7 +78,7 @@ fn client_to_server_outbound() { let bg_thread = thread::spawn(move || { let transport = TcpConfig::new() - .with_upgrade(multiplex::MultiplexConfig::new()) + .with_upgrade(multiplex::MplexConfig::new()) .into_connection_reuse(); let (listener, addr) = transport @@ -107,7 +107,7 @@ fn client_to_server_outbound() { tokio_current_thread::block_on_all(future).unwrap(); }); - let transport = TcpConfig::new().with_upgrade(multiplex::MultiplexConfig::new()); + let transport = TcpConfig::new().with_upgrade(multiplex::MplexConfig::new()); let future = transport .dial(rx.recv().unwrap()) @@ -130,7 +130,7 @@ fn connection_reused_for_dialing() { let bg_thread = thread::spawn(move || { let transport = OnlyOnce::from(TcpConfig::new()) - .with_upgrade(multiplex::MultiplexConfig::new()) + .with_upgrade(multiplex::MplexConfig::new()) .into_connection_reuse(); let (listener, addr) = transport @@ -171,7 +171,7 @@ fn connection_reused_for_dialing() { }); let transport = OnlyOnce::from(TcpConfig::new()) - .with_upgrade(multiplex::MultiplexConfig::new()) + .with_upgrade(multiplex::MplexConfig::new()) .into_connection_reuse(); let listen_addr = rx.recv().unwrap(); @@ -207,7 +207,7 @@ fn use_opened_listen_to_dial() { let bg_thread = thread::spawn(move || { let transport = OnlyOnce::from(TcpConfig::new()) - .with_upgrade(multiplex::MultiplexConfig::new()); + .with_upgrade(multiplex::MplexConfig::new()); let (listener, addr) = transport .clone() @@ -248,7 +248,7 @@ fn use_opened_listen_to_dial() { }); let transport = OnlyOnce::from(TcpConfig::new()) - .with_upgrade(multiplex::MultiplexConfig::new()) + .with_upgrade(multiplex::MplexConfig::new()) .into_connection_reuse(); let listen_addr = rx.recv().unwrap(); diff --git a/libp2p/examples/echo-dialer.rs b/libp2p/examples/echo-dialer.rs index 0e43fd41..18e71a58 100644 --- a/libp2p/examples/echo-dialer.rs +++ b/libp2p/examples/echo-dialer.rs @@ -70,8 +70,8 @@ fn main() { }) // On top of plaintext or secio, we will use the multiplex protocol. - .with_upgrade(libp2p::mplex::MultiplexConfig::new()) - // The object returned by the call to `with_upgrade(MultiplexConfig::new())` can't be used as a + .with_upgrade(libp2p::mplex::MplexConfig::new()) + // The object returned by the call to `with_upgrade(MplexConfig::new())` can't be used as a // `Transport` because the output of the upgrade is not a stream but a controller for // muxing. We have to explicitly call `into_connection_reuse()` in order to turn this into // a `Transport`. diff --git a/libp2p/examples/echo-server.rs b/libp2p/examples/echo-server.rs index 1822cce5..4ad06e66 100644 --- a/libp2p/examples/echo-server.rs +++ b/libp2p/examples/echo-server.rs @@ -70,8 +70,8 @@ fn main() { }) // On top of plaintext or secio, we will use the multiplex protocol. - .with_upgrade(libp2p::mplex::MultiplexConfig::new()) - // The object returned by the call to `with_upgrade(MultiplexConfig::new())` can't be used as a + .with_upgrade(libp2p::mplex::MplexConfig::new()) + // The object returned by the call to `with_upgrade(MplexConfig::new())` can't be used as a // `Transport` because the output of the upgrade is not a stream but a controller for // muxing. We have to explicitly call `into_connection_reuse()` in order to turn this into // a `Transport`. diff --git a/libp2p/examples/floodsub.rs b/libp2p/examples/floodsub.rs index ce24e3d9..e86e101d 100644 --- a/libp2p/examples/floodsub.rs +++ b/libp2p/examples/floodsub.rs @@ -71,8 +71,8 @@ fn main() { }) // On top of plaintext or secio, we will use the multiplex protocol. - .with_upgrade(libp2p::mplex::MultiplexConfig::new()) - // The object returned by the call to `with_upgrade(MultiplexConfig::new())` can't be used as a + .with_upgrade(libp2p::mplex::MplexConfig::new()) + // The object returned by the call to `with_upgrade(MplexConfig::new())` can't be used as a // `Transport` because the output of the upgrade is not a stream but a controller for // muxing. We have to explicitly call `into_connection_reuse()` in order to turn this into // a `Transport`. diff --git a/libp2p/examples/kademlia.rs b/libp2p/examples/kademlia.rs index 0b79464e..072af30c 100644 --- a/libp2p/examples/kademlia.rs +++ b/libp2p/examples/kademlia.rs @@ -78,8 +78,8 @@ fn main() { }) // On top of plaintext or secio, we will use the multiplex protocol. - .with_upgrade(libp2p::mplex::MultiplexConfig::new()) - // The object returned by the call to `with_upgrade(MultiplexConfig::new())` can't be used as a + .with_upgrade(libp2p::mplex::MplexConfig::new()) + // The object returned by the call to `with_upgrade(MplexConfig::new())` can't be used as a // `Transport` because the output of the upgrade is not a stream but a controller for // muxing. We have to explicitly call `into_connection_reuse()` in order to turn this into // a `Transport`. diff --git a/libp2p/examples/ping-client.rs b/libp2p/examples/ping-client.rs index b4c13768..60c448ee 100644 --- a/libp2p/examples/ping-client.rs +++ b/libp2p/examples/ping-client.rs @@ -63,8 +63,8 @@ fn main() { }) // On top of plaintext or secio, we will use the multiplex protocol. - .with_upgrade(libp2p::mplex::MultiplexConfig::new()) - // The object returned by the call to `with_upgrade(MultiplexConfig::new())` can't be used as a + .with_upgrade(libp2p::mplex::MplexConfig::new()) + // The object returned by the call to `with_upgrade(MplexConfig::new())` can't be used as a // `Transport` because the output of the upgrade is not a stream but a controller for // muxing. We have to explicitly call `into_connection_reuse()` in order to turn this into // a `Transport`. diff --git a/mplex/Cargo.toml b/mplex/Cargo.toml index 1162fc7b..c3a870b4 100644 --- a/mplex/Cargo.toml +++ b/mplex/Cargo.toml @@ -8,6 +8,7 @@ arrayvec = "0.4.6" bytes = "0.4.5" circular-buffer = { path = "../circular-buffer" } error-chain = "0.11.0" +fnv = "1.0" futures = "0.1" futures-mutex = { git = "https://github.com/paritytech/futures-mutex" } libp2p-core = { path = "../core" } diff --git a/mplex/src/codec.rs b/mplex/src/codec.rs new file mode 100644 index 00000000..a6946b2e --- /dev/null +++ b/mplex/src/codec.rs @@ -0,0 +1,187 @@ +// Copyright 2018 Parity Technologies (UK) Ltd. +// +// Permission is hereby granted, free of charge, to any person obtaining a +// copy of this software and associated documentation files (the "Software"), +// to deal in the Software without restriction, including without limitation +// the rights to use, copy, modify, merge, publish, distribute, sublicense, +// and/or sell copies of the Software, and to permit persons to whom the +// Software is furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS +// OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +// FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +// DEALINGS IN THE SOFTWARE. + +use std::cmp; +use std::io::{Error as IoError, ErrorKind as IoErrorKind}; +use std::mem; +use bytes::{BufMut, BytesMut}; +use core::Endpoint; +use tokio_io::codec::{Decoder, Encoder}; +use varint; + +// Arbitrary maximum size for a packet. +// Since data is entirely buffered before being dispatched, we need a limit or remotes could just +// send a 4 TB-long packet full of zeroes that we kill our process with an OOM error. +const MAX_FRAME_SIZE: usize = 32 * 1024 * 1024; + +#[derive(Debug, Clone)] +pub enum Elem { + Open { substream_id: u32 }, + Data { substream_id: u32, endpoint: Endpoint, data: BytesMut }, + Close { substream_id: u32, endpoint: Endpoint }, + Reset { substream_id: u32, endpoint: Endpoint }, +} + +impl Elem { + /// Returns the ID of the substream of the message. + pub fn substream_id(&self) -> u32 { + match *self { + Elem::Open { substream_id } => substream_id, + Elem::Data { substream_id, .. } => substream_id, + Elem::Close { substream_id, .. } => substream_id, + Elem::Reset { substream_id, .. } => substream_id, + } + } +} + +pub struct Codec { + varint_decoder: varint::VarintDecoder, + decoder_state: CodecDecodeState, +} + +#[derive(Debug, Clone)] +enum CodecDecodeState { + Begin, + HasHeader(u32), + HasHeaderAndLen(u32, usize, BytesMut), + Poisoned, +} + +impl Codec { + pub fn new() -> Codec { + Codec { + varint_decoder: varint::VarintDecoder::new(), + decoder_state: CodecDecodeState::Begin, + } + } +} + +impl Decoder for Codec { + type Item = Elem; + type Error = IoError; + + fn decode(&mut self, src: &mut BytesMut) -> Result, Self::Error> { + loop { + match mem::replace(&mut self.decoder_state, CodecDecodeState::Poisoned) { + CodecDecodeState::Begin => { + match self.varint_decoder.decode(src)? { + Some(header) => { + self.decoder_state = CodecDecodeState::HasHeader(header); + }, + None => { + self.decoder_state = CodecDecodeState::Begin; + return Ok(None); + }, + } + }, + CodecDecodeState::HasHeader(header) => { + match self.varint_decoder.decode(src)? { + Some(len) => { + if len as usize > MAX_FRAME_SIZE { + return Err(IoErrorKind::InvalidData.into()); + } + + self.decoder_state = CodecDecodeState::HasHeaderAndLen(header, len as usize, BytesMut::with_capacity(len as usize)); + }, + None => { + self.decoder_state = CodecDecodeState::HasHeader(header); + return Ok(None); + }, + } + }, + CodecDecodeState::HasHeaderAndLen(header, len, mut buf) => { + debug_assert!(len == 0 || buf.len() < len); + let to_transfer = cmp::min(src.len(), len - buf.len()); + + buf.put(src.split_to(to_transfer)); // TODO: more optimal? + + if buf.len() < len { + self.decoder_state = CodecDecodeState::HasHeaderAndLen(header, len, buf); + return Ok(None); + } + + self.decoder_state = CodecDecodeState::Begin; + let substream_id = (header >> 3) as u32; + let out = match header & 7 { + 0 => Elem::Open { substream_id }, + 1 => Elem::Data { substream_id, endpoint: Endpoint::Listener, data: buf }, + 2 => Elem::Data { substream_id, endpoint: Endpoint::Dialer, data: buf }, + 3 => Elem::Close { substream_id, endpoint: Endpoint::Listener }, + 4 => Elem::Close { substream_id, endpoint: Endpoint::Dialer }, + 5 => Elem::Reset { substream_id, endpoint: Endpoint::Listener }, + 6 => Elem::Reset { substream_id, endpoint: Endpoint::Dialer }, + _ => return Err(IoErrorKind::InvalidData.into()), + }; + + return Ok(Some(out)); + }, + + CodecDecodeState::Poisoned => { + return Err(IoErrorKind::InvalidData.into()); + } + } + } + } +} + +impl Encoder for Codec { + type Item = Elem; + type Error = IoError; + + fn encode(&mut self, item: Self::Item, dst: &mut BytesMut) -> Result<(), Self::Error> { + let (header, data) = match item { + Elem::Open { substream_id } => { + ((substream_id as u64) << 3, BytesMut::new()) + }, + Elem::Data { substream_id, endpoint: Endpoint::Listener, data } => { + ((substream_id as u64) << 3 | 1, data) + }, + Elem::Data { substream_id, endpoint: Endpoint::Dialer, data } => { + ((substream_id as u64) << 3 | 2, data) + }, + Elem::Close { substream_id, endpoint: Endpoint::Listener } => { + ((substream_id as u64) << 3 | 3, BytesMut::new()) + }, + Elem::Close { substream_id, endpoint: Endpoint::Dialer } => { + ((substream_id as u64) << 3 | 4, BytesMut::new()) + }, + Elem::Reset { substream_id, endpoint: Endpoint::Listener } => { + ((substream_id as u64) << 3 | 5, BytesMut::new()) + }, + Elem::Reset { substream_id, endpoint: Endpoint::Dialer } => { + ((substream_id as u64) << 3 | 6, BytesMut::new()) + }, + }; + + let header_bytes = varint::encode(header); + let data_len = data.as_ref().len(); + let data_len_bytes = varint::encode(data_len); + + if data_len > MAX_FRAME_SIZE { + return Err(IoError::new(IoErrorKind::InvalidData, "data size exceed maximum")); + } + + dst.reserve(header_bytes.len() + data_len_bytes.len() + data_len); + dst.put(header_bytes); + dst.put(data_len_bytes); + dst.put(data); + Ok(()) + } +} diff --git a/mplex/src/header.rs b/mplex/src/header.rs deleted file mode 100644 index f0346eb3..00000000 --- a/mplex/src/header.rs +++ /dev/null @@ -1,127 +0,0 @@ -// Copyright 2017 Parity Technologies (UK) Ltd. -// -// Permission is hereby granted, free of charge, to any person obtaining a -// copy of this software and associated documentation files (the "Software"), -// to deal in the Software without restriction, including without limitation -// the rights to use, copy, modify, merge, publish, distribute, sublicense, -// and/or sell copies of the Software, and to permit persons to whom the -// Software is furnished to do so, subject to the following conditions: -// -// The above copyright notice and this permission notice shall be included in -// all copies or substantial portions of the Software. -// -// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS -// OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING -// FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER -// DEALINGS IN THE SOFTWARE. - -use swarm::Endpoint; - -const FLAG_BITS: usize = 3; -const FLAG_MASK: usize = (1usize << FLAG_BITS) - 1; - -pub mod errors { - error_chain! { - errors { - ParseError - } - } -} - -#[derive(Copy, Clone, PartialEq, Eq, Debug)] -pub struct MultiplexHeader { - pub packet_type: PacketType, - pub substream_id: u32, -} - -#[derive(Copy, Clone, PartialEq, Eq, Debug)] -pub enum PacketType { - Open, - Close(Endpoint), - Reset(Endpoint), - Message(Endpoint), -} - -impl MultiplexHeader { - pub fn open(id: u32) -> Self { - MultiplexHeader { - substream_id: id, - packet_type: PacketType::Open, - } - } - - pub fn message(id: u32, end: Endpoint) -> Self { - MultiplexHeader { - substream_id: id, - packet_type: PacketType::Message(end), - } - } - - // TODO: Use `u128` or another large integer type instead of bigint since we never use more than - // `pointer width + FLAG_BITS` bits and unconditionally allocating 1-3 `u32`s for that is - // ridiculous (especially since even for small numbers we have to allocate 1 `u32`). - // If this is the future and `BigUint` is better-optimised (maybe by using `Bytes`) then - // forget it. - pub fn parse(header: u64) -> Result { - use num_traits::cast::ToPrimitive; - - let flags = header & FLAG_MASK as u64; - - let substream_id = (header >> FLAG_BITS) - .to_u32() - .ok_or(errors::ErrorKind::ParseError)?; - - // Yes, this is really how it works. No, I don't know why. - let packet_type = match flags { - 0 => PacketType::Open, - - 1 => PacketType::Message(Endpoint::Listener), - 2 => PacketType::Message(Endpoint::Dialer), - - 3 => PacketType::Close(Endpoint::Listener), - 4 => PacketType::Close(Endpoint::Dialer), - - 5 => PacketType::Reset(Endpoint::Listener), - 6 => PacketType::Reset(Endpoint::Dialer), - - _ => { - use std::io; - - return Err(errors::Error::with_chain( - io::Error::new( - io::ErrorKind::Other, - format!("Unexpected packet type: {}", flags), - ), - errors::ErrorKind::ParseError, - )); - } - }; - - Ok(MultiplexHeader { - substream_id, - packet_type, - }) - } - - pub fn to_u64(&self) -> u64 { - let packet_type_id = match self.packet_type { - PacketType::Open => 0, - - PacketType::Message(Endpoint::Listener) => 1, - PacketType::Message(Endpoint::Dialer) => 2, - - PacketType::Close(Endpoint::Listener) => 3, - PacketType::Close(Endpoint::Dialer) => 4, - - PacketType::Reset(Endpoint::Listener) => 5, - PacketType::Reset(Endpoint::Dialer) => 6, - }; - - let substream_id = (self.substream_id as u64) << FLAG_BITS; - - substream_id | packet_type_id - } -} diff --git a/mplex/src/lib.rs b/mplex/src/lib.rs index 51a29735..5b0224fe 100644 --- a/mplex/src/lib.rs +++ b/mplex/src/lib.rs @@ -1,4 +1,4 @@ -// Copyright 2017 Parity Technologies (UK) Ltd. +// Copyright 2018 Parity Technologies (UK) Ltd. // // Permission is hereby granted, free of charge, to any person obtaining a // copy of this software and associated documentation files (the "Software"), @@ -18,367 +18,71 @@ // FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER // DEALINGS IN THE SOFTWARE. -extern crate arrayvec; extern crate bytes; -extern crate circular_buffer; +extern crate fnv; #[macro_use] -extern crate error_chain; extern crate futures; -extern crate futures_mutex; -extern crate libp2p_core as swarm; +extern crate libp2p_core as core; #[macro_use] extern crate log; -extern crate num_bigint; -extern crate num_traits; extern crate parking_lot; -extern crate rand; extern crate tokio_io; extern crate varint; -mod header; -mod read; -mod shared; -mod write; +mod codec; -use bytes::Bytes; -use circular_buffer::Array; -use futures::future::{self, FutureResult}; -use futures::{Async, Future, Poll}; -use futures_mutex::Mutex; -use header::MultiplexHeader; -use read::{read_stream, MultiplexReadState}; -use shared::{buf_from_slice, ByteBuf, MultiplexShared}; -use std::io::{self, Read, Write}; -use std::iter; +use std::{cmp, iter}; +use std::io::{Read, Write, Error as IoError, ErrorKind as IoErrorKind}; +use std::mem; use std::sync::Arc; -use std::sync::atomic::{self, AtomicUsize}; -use swarm::muxing::StreamMuxer; -use swarm::{ConnectionUpgrade, Endpoint}; -use tokio_io::{AsyncRead, AsyncWrite}; -use write::write_stream; +use bytes::Bytes; +use core::{ConnectionUpgrade, Endpoint, StreamMuxer}; +use parking_lot::Mutex; +use fnv::FnvHashSet; +use futures::prelude::*; +use futures::{future, stream::Fuse, task}; +use tokio_io::{AsyncRead, AsyncWrite, codec::Framed}; -// So the multiplex is essentially a distributed finite state machine. -// -// In the first state the header must be read so that we know which substream to hand off the -// upcoming packet to. This is first-come, first-served - whichever substream begins reading the -// packet will be locked into reading the header until it is consumed (this may be changed in the -// future, for example by allowing the streams to cooperate on parsing headers). This implementation -// of `Multiplex` operates under the assumption that all substreams are consumed relatively equally. -// A higher-level wrapper may wrap this and add some level of buffering. -// -// In the second state, the substream ID is known. Only this substream can progress until the packet -// is consumed. +// Maximum number of simultaneously-open substreams. +const MAX_SUBSTREAMS: usize = 1024; +// Maximum number of elements in the internal buffer. +const MAX_BUFFER_LEN: usize = 256; -pub struct Substream { - id: u32, - end: Endpoint, - name: Option, - state: Arc>>, - buffer: Option>, -} +/// Configuration for the multiplexer. +#[derive(Debug, Clone, Default)] +pub struct MplexConfig; -impl Drop for Substream { - fn drop(&mut self) { - let mut lock = self.state.lock().wait().expect("This should never fail"); - - lock.close_stream(self.id); +impl MplexConfig { + /// Builds the default configuration. + #[inline] + pub fn new() -> MplexConfig { + Default::default() } } -impl Substream { - fn new>>( - id: u32, - end: Endpoint, - name: B, - state: Arc>>, - ) -> Self { - let name = name.into(); - - Substream { - id, - end, - name, - state, - buffer: None, - } - } - - pub fn name(&self) -> Option<&Bytes> { - self.name.as_ref() - } - - pub fn id(&self) -> u32 { - self.id - } -} - -// TODO: We always zero the buffer, we should delegate to the inner stream. -impl> Read for Substream { - fn read(&mut self, buf: &mut [u8]) -> io::Result { - let mut lock = match self.state.poll_lock() { - Async::Ready(lock) => lock, - Async::NotReady => return Err(io::ErrorKind::WouldBlock.into()), - }; - - read_stream(&mut lock, (self.id, buf)) - } -} - -impl> AsyncRead for Substream {} - -impl Write for Substream { - fn write(&mut self, buf: &[u8]) -> io::Result { - let mut lock = match self.state.poll_lock() { - Async::Ready(lock) => lock, - Async::NotReady => return Err(io::ErrorKind::WouldBlock.into()), - }; - - let mut buffer = self.buffer - .take() - .unwrap_or_else(|| io::Cursor::new(buf_from_slice(buf))); - - let out = write_stream( - &mut *lock, - write::WriteRequest::substream(MultiplexHeader::message(self.id, self.end)), - &mut buffer, - ); - - if buffer.position() < buffer.get_ref().len() as u64 { - self.buffer = Some(buffer); - } - - out - } - - fn flush(&mut self) -> io::Result<()> { - let mut lock = match self.state.poll_lock() { - Async::Ready(lock) => lock, - Async::NotReady => return Err(io::ErrorKind::WouldBlock.into()), - }; - - lock.stream.flush() - } -} - -impl AsyncWrite for Substream { - fn shutdown(&mut self) -> Poll<(), io::Error> { - Ok(Async::Ready(())) - } -} - -pub struct InboundFuture { - end: Endpoint, - state: Arc>>, -} - -impl> Future for InboundFuture { - type Item = Option>; - type Error = io::Error; - - fn poll(&mut self) -> Poll { - let mut lock = match self.state.poll_lock() { - Async::Ready(lock) => lock, - Async::NotReady => return Ok(Async::NotReady), - }; - - if lock.is_closed() { - return Ok(Async::Ready(None)); - } - - // Attempt to make progress, but don't block if we can't - match read_stream(&mut lock, None) { - Ok(_) => {} - Err(ref err) if err.kind() == io::ErrorKind::WouldBlock => {} - Err(err) => return Err(err), - } - - let id = if let Some((id, _)) = lock.to_open.iter().next() { - *id - } else { - return Ok(Async::NotReady); - }; - - let name = lock.to_open.remove(&id).expect( - "We just checked that this key exists and we have exclusive access to the map, QED", - ); - - lock.open_stream(id); - - Ok(Async::Ready(Some(Substream::new( - id, - self.end, - name, - Arc::clone(&self.state), - )))) - } -} - -pub struct OutboundFuture { - meta: Arc, - current_id: Option<(io::Cursor, u32)>, - state: Arc>>, -} - -impl OutboundFuture { - fn new(muxer: BufferedMultiplex) -> Self { - OutboundFuture { - current_id: None, - meta: muxer.meta, - state: muxer.state, - } - } -} - -fn nonce_to_id(id: usize, end: Endpoint) -> u32 { - id as u32 * 2 + if end == Endpoint::Dialer { 0 } else { 1 } -} - -impl Future for OutboundFuture { - type Item = Option>; - type Error = io::Error; - - fn poll(&mut self) -> Poll { - let mut lock = match self.state.poll_lock() { - Async::Ready(lock) => lock, - Async::NotReady => return Ok(Async::NotReady), - }; - - if lock.is_closed() { - return Ok(Async::Ready(None)); - } - - loop { - let (mut id_str, id) = self.current_id.take().unwrap_or_else(|| { - let next = nonce_to_id( - self.meta.nonce.fetch_add(1, atomic::Ordering::Relaxed), - self.meta.end, - ); - ( - io::Cursor::new(buf_from_slice(format!("{}", next).as_bytes())), - next as u32, - ) - }); - - match write_stream( - &mut *lock, - write::WriteRequest::meta(MultiplexHeader::open(id)), - &mut id_str, - ) { - Ok(_) => { - debug_assert!(id_str.position() <= id_str.get_ref().len() as u64); - if id_str.position() == id_str.get_ref().len() as u64 { - if lock.open_stream(id) { - return Ok(Async::Ready(Some(Substream::new( - id, - self.meta.end, - Bytes::from(&id_str.get_ref()[..]), - Arc::clone(&self.state), - )))); - } - } else { - self.current_id = Some((id_str, id)); - return Ok(Async::NotReady); - } - } - Err(ref err) if err.kind() == io::ErrorKind::WouldBlock => { - self.current_id = Some((id_str, id)); - - return Ok(Async::NotReady); - } - Err(other) => return Err(other), - } - } - } -} - -pub struct MultiplexMetadata { - nonce: AtomicUsize, - end: Endpoint, -} - -pub type Multiplex = BufferedMultiplex; - -pub struct BufferedMultiplex { - meta: Arc, - state: Arc>>, -} - -impl Clone for BufferedMultiplex { - fn clone(&self) -> Self { - BufferedMultiplex { - meta: self.meta.clone(), - state: self.state.clone(), - } - } -} - -impl BufferedMultiplex { - pub fn new(stream: T, end: Endpoint) -> Self { - BufferedMultiplex { - meta: Arc::new(MultiplexMetadata { - nonce: AtomicUsize::new(0), - end, - }), - state: Arc::new(Mutex::new(MultiplexShared::new(stream))), - } - } - - pub fn dial(stream: T) -> Self { - Self::new(stream, Endpoint::Dialer) - } - - pub fn listen(stream: T) -> Self { - Self::new(stream, Endpoint::Listener) - } -} - -impl> StreamMuxer for BufferedMultiplex { - type Substream = Substream; - type OutboundSubstream = OutboundFuture; - type InboundSubstream = InboundFuture; - - fn inbound(self) -> Self::InboundSubstream { - InboundFuture { - state: Arc::clone(&self.state), - end: self.meta.end, - } - } - - fn outbound(self) -> Self::OutboundSubstream { - OutboundFuture::new(self) - } -} - -pub type MultiplexConfig = BufferedMultiplexConfig<[u8; 0]>; - -#[derive(Debug, Copy, Clone)] -pub struct BufferedMultiplexConfig(std::marker::PhantomData); - -impl Default for BufferedMultiplexConfig { - fn default() -> Self { - BufferedMultiplexConfig(std::marker::PhantomData) - } -} - -impl BufferedMultiplexConfig { - pub fn new() -> Self { - Self::default() - } -} - -impl ConnectionUpgrade for BufferedMultiplexConfig +impl ConnectionUpgrade for MplexConfig where C: AsyncRead + AsyncWrite, { - type Output = BufferedMultiplex; + type Output = Multiplex; type MultiaddrFuture = Maf; - type Future = FutureResult<(BufferedMultiplex, Maf), io::Error>; + type Future = future::FutureResult<(Self::Output, Self::MultiaddrFuture), IoError>; type UpgradeIdentifier = (); type NamesIter = iter::Once<(Bytes, ())>; #[inline] - fn upgrade(self, i: C, _: (), end: Endpoint, remote_addr: Maf) -> Self::Future { - future::ok((BufferedMultiplex::new(i, end), remote_addr)) + fn upgrade(self, i: C, _: (), endpoint: Endpoint, remote_addr: Maf) -> Self::Future { + let out = Multiplex { + inner: Arc::new(Mutex::new(MultiplexInner { + inner: i.framed(codec::Codec::new()).fuse(), + buffer: Vec::with_capacity(32), + opened_substreams: Default::default(), + next_outbound_stream_id: if endpoint == Endpoint::Dialer { 0 } else { 1 }, + to_notify: Vec::new(), + })) + }; + + future::ok((out, remote_addr)) } #[inline] @@ -387,361 +91,302 @@ where } } -#[cfg(test)] -mod tests { - use super::*; - use header::PacketType; - use std::io; - use tokio_io::io as tokio; +/// Multiplexer. Implements the `StreamMuxer` trait. +pub struct Multiplex { + inner: Arc>>, +} - #[test] - fn can_use_one_stream() { - let message = b"Hello, world!"; - - let stream = io::Cursor::new(Vec::new()); - - let mplex = Multiplex::dial(stream); - - let mut substream = mplex - .clone() - .outbound() - .wait() - .unwrap() - .expect("outbound substream"); - - assert!(tokio::write_all(&mut substream, message).wait().is_ok()); - - let id = substream.id(); - - assert_eq!( - substream - .name() - .and_then(|bytes| String::from_utf8(bytes.to_vec()).ok()), - Some(id.to_string()) - ); - - let stream = io::Cursor::new(mplex.state.lock().wait().unwrap().stream.get_ref().clone()); - - let mplex = Multiplex::listen(stream); - - let mut substream = mplex.inbound().wait().unwrap().expect("inbound substream"); - - assert_eq!(id, substream.id()); - assert_eq!( - substream - .name() - .and_then(|bytes| String::from_utf8(bytes.to_vec()).ok()), - Some(id.to_string()) - ); - - let mut buf = vec![0; message.len()]; - - assert!(tokio::read(&mut substream, &mut buf).wait().is_ok()); - assert_eq!(&buf, message); - } - - #[test] - fn can_use_many_streams() { - let stream = io::Cursor::new(Vec::new()); - - let mplex = Multiplex::dial(stream); - - let mut outbound: Vec> = vec![]; - - for _ in 0..5 { - outbound.push( - mplex - .clone() - .outbound() - .wait() - .unwrap() - .expect("outbound substream"), - ); +impl Clone for Multiplex { + #[inline] + fn clone(&self) -> Self { + Multiplex { + inner: self.inner.clone(), } - - outbound.sort_by_key(|a| a.id()); - - for (i, substream) in outbound.iter_mut().enumerate() { - assert!( - tokio::write_all(substream, i.to_string().as_bytes()) - .wait() - .is_ok() - ); - } - - let stream = io::Cursor::new(mplex.state.lock().wait().unwrap().stream.get_ref().clone()); - - let mplex = Multiplex::listen(stream); - - let mut inbound: Vec> = vec![]; - - for _ in 0..5 { - inbound.push( - mplex - .clone() - .inbound() - .wait() - .unwrap() - .expect("inbound substream"), - ); - } - - inbound.sort_by_key(|a| a.id()); - - for (mut substream, outbound) in inbound.iter_mut().zip(outbound.iter()) { - let id = outbound.id(); - assert_eq!(id, substream.id()); - assert_eq!( - substream - .name() - .and_then(|bytes| String::from_utf8(bytes.to_vec()).ok()), - Some(id.to_string()) - ); - - let mut buf = [0; 3]; - assert_eq!(tokio::read(&mut substream, &mut buf).wait().unwrap().2, 1); - } - } - - #[test] - fn packets_to_unopened_streams_are_dropped() { - use std::iter; - - let message = b"Hello, world!"; - - // We use a large dummy length to exercise ignoring data longer than `ignore_buffer.len()` - let dummy_length = 1000; - - let input = iter::empty() - // Open a stream - .chain(varint::encode(MultiplexHeader::open(0).to_u64())) - // 0-length body (stream has no name) - .chain(varint::encode(0usize)) - - // "Message"-type packet for an unopened stream - .chain( - varint::encode( - // ID for an unopened stream: 1 - MultiplexHeader::message(1, Endpoint::Dialer).to_u64(), - ).into_iter(), - ) - // Body: `dummy_length` of zeroes - .chain(varint::encode(dummy_length)) - .chain(iter::repeat(0).take(dummy_length)) - - // "Message"-type packet for an opened stream - .chain( - varint::encode( - // ID for an opened stream: 0 - MultiplexHeader::message(0, Endpoint::Dialer).to_u64(), - ).into_iter(), - ) - .chain(varint::encode(message.len())) - .chain(message.iter().cloned()) - - .collect::>(); - - let mplex = Multiplex::listen(io::Cursor::new(input)); - - let mut substream = mplex.inbound().wait().unwrap().expect("inbound substream"); - - assert_eq!(substream.id(), 0); - assert_eq!(substream.name(), None); - - let mut buf = vec![0; message.len()]; - - assert!(tokio::read(&mut substream, &mut buf).wait().is_ok()); - assert_eq!(&buf, message); - } - - #[test] - fn can_close_streams() { - use std::iter; - - // Dummy data in the body of the close packet (since the de facto protocol is to accept but - // ignore this data) - let dummy_length = 64; - - let input = iter::empty() - // Open a stream - .chain(varint::encode(MultiplexHeader::open(0).to_u64())) - // 0-length body (stream has no name) - .chain(varint::encode(0usize)) - - // Immediately close the stream - .chain( - varint::encode( - // ID for an unopened stream: 1 - MultiplexHeader { - packet_type: PacketType::Close(Endpoint::Dialer), - substream_id: 0, - }.to_u64(), - ).into_iter(), - ) - .chain(varint::encode(dummy_length)) - .chain(iter::repeat(0).take(dummy_length)) - - // Send packet to the closed stream - .chain( - varint::encode( - // ID for an opened stream: 0 - MultiplexHeader::message(0, Endpoint::Dialer).to_u64(), - ).into_iter(), - ) - .chain(varint::encode(dummy_length)) - .chain(iter::repeat(0).take(dummy_length)) - - .collect::>(); - - let mplex = Multiplex::listen(io::Cursor::new(input)); - - let mut substream = mplex.inbound().wait().unwrap().expect("inbound substream"); - - assert_eq!(substream.id(), 0); - assert_eq!(substream.name(), None); - - assert_eq!( - tokio::read(&mut substream, &mut [0; 100][..]) - .wait() - .unwrap() - .2, - 0 - ); - } - - #[test] - fn real_world_data() { - let data: Vec = vec![ - // Open stream 1 - 8, - 0, - - // Message for stream 1 (length 20) - 10, - 20, - 19, - 47, - 109, - 117, - 108, - 116, - 105, - 115, - 116, - 114, - 101, - 97, - 109, - 47, - 49, - 46, - 48, - 46, - 48, - 10, - ]; - - let mplex = Multiplex::listen(io::Cursor::new(data)); - - let mut substream = mplex.inbound().wait().unwrap().expect("inbound substream"); - - assert_eq!(substream.id(), 1); - - assert_eq!(substream.name(), None); - - let mut out = vec![]; - - for _ in 0..20 { - let mut buf = [0; 1]; - - assert_eq!( - tokio::read(&mut substream, &mut buf[..]).wait().unwrap().2, - 1 - ); - - out.push(buf[0]); - } - - assert_eq!(out[0], 19); - assert_eq!(&out[1..0x14 - 1], b"/multistream/1.0.0"); - assert_eq!(out[0x14 - 1], 0x0a); - } - - #[test] - fn can_buffer() { - type Buffer = [u8; 1024]; - - let stream = io::Cursor::new(Vec::new()); - - let mplex = BufferedMultiplex::<_, Buffer>::dial(stream); - - let mut outbound: Vec> = vec![]; - - for _ in 0..5 { - outbound.push( - mplex - .clone() - .outbound() - .wait() - .unwrap() - .expect("outbound substream"), - ); - } - - outbound.sort_by_key(|a| a.id()); - - for (i, substream) in outbound.iter_mut().enumerate() { - assert!( - tokio::write_all(substream, i.to_string().as_bytes()) - .wait() - .is_ok() - ); - } - - let stream = io::Cursor::new(mplex.state.lock().wait().unwrap().stream.get_ref().clone()); - - let mplex = BufferedMultiplex::<_, Buffer>::listen(stream); - - let mut inbound: Vec> = vec![]; - - for _ in 0..5 { - let inb: Substream<_, Buffer> = mplex - .clone() - .inbound() - .wait() - .unwrap() - .expect("inbound substream"); - inbound.push(inb); - } - - inbound.sort_by_key(|a| a.id()); - - // Skip the first substream and let it be cached. - for (mut substream, outbound) in inbound.iter_mut().zip(outbound.iter()).skip(1) { - let id = outbound.id(); - assert_eq!(id, substream.id()); - assert_eq!( - substream - .name() - .and_then(|bytes| String::from_utf8(bytes.to_vec()).ok()), - Some(id.to_string()) - ); - - let mut buf = [0; 3]; - assert_eq!(tokio::read(&mut substream, &mut buf).wait().unwrap().2, 1); - } - - let (mut substream, outbound) = (&mut inbound[0], &outbound[0]); - let id = outbound.id(); - assert_eq!(id, substream.id()); - assert_eq!( - substream - .name() - .and_then(|bytes| String::from_utf8(bytes.to_vec()).ok()), - Some(id.to_string()) - ); - - let mut buf = [0; 3]; - assert_eq!(tokio::read(&mut substream, &mut buf).wait().unwrap().2, 1); + } +} + +// Struct shared throughout the implementation. +struct MultiplexInner { + // Underlying stream. + inner: Fuse>, + // Buffer of elements pulled from the stream but not processed yet. + buffer: Vec, + // List of Ids of opened substreams. Used to filter out messages that don't belong to any + // substream. + opened_substreams: FnvHashSet, + // Id of the next outgoing substream. Should always increase by two. + next_outbound_stream_id: u32, + // List of tasks to notify when a new element is inserted in `buffer`. + to_notify: Vec, +} + +// Processes elements in `inner` until one matching `filter` is found. +// +// If `NotReady` is returned, the current task is scheduled for later, just like with any `Poll`. +// `Ready(Some())` is almost always returned. `Ready(None)` is returned if the stream is EOF. +fn next_match(inner: &mut MultiplexInner, mut filter: F) -> Poll, IoError> +where C: AsyncRead + AsyncWrite, + F: FnMut(&codec::Elem) -> Option, +{ + if let Some((offset, out)) = inner.buffer.iter().enumerate().filter_map(|(n, v)| filter(v).map(|v| (n, v))).next() { + inner.buffer.remove(offset); + return Ok(Async::Ready(Some(out))); + } + + loop { + let elem = match inner.inner.poll() { + Ok(Async::Ready(item)) => item, + Ok(Async::NotReady) => { + inner.to_notify.push(task::current()); + return Ok(Async::NotReady); + }, + Err(err) => { + return Err(err); + }, + }; + + if let Some(elem) = elem { + if let Some(out) = filter(&elem) { + return Ok(Async::Ready(Some(out))); + } else { + if inner.buffer.len() >= MAX_BUFFER_LEN { + return Err(IoError::new(IoErrorKind::InvalidData, "reached maximum buffer length")); + } + + if inner.opened_substreams.contains(&elem.substream_id()) { + inner.buffer.push(elem); + for task in inner.to_notify.drain(..) { + task.notify(); + } + } + } + } else { + return Ok(Async::Ready(None)); + } + } +} + +// Closes a substream in `inner`. +fn clean_out_substream(inner: &mut MultiplexInner, num: u32) { + let was_in = inner.opened_substreams.remove(&num); + debug_assert!(was_in, "Dropped substream which wasn't open ; programmer error"); + inner.buffer.retain(|elem| elem.substream_id() != num); +} + +// Small convenience function that tries to write `elem` to the stream. +fn poll_send(inner: &mut MultiplexInner, elem: codec::Elem) -> Poll<(), IoError> +where C: AsyncRead + AsyncWrite +{ + match inner.inner.start_send(elem) { + Ok(AsyncSink::Ready) => { + Ok(Async::Ready(())) + }, + Ok(AsyncSink::NotReady(_)) => { + Ok(Async::NotReady) + }, + Err(err) => Err(err) + } +} + +impl StreamMuxer for Multiplex +where C: AsyncRead + AsyncWrite + 'static // TODO: 'static :-/ +{ + type Substream = Substream; + type InboundSubstream = InboundSubstream; + type OutboundSubstream = Box, Error = IoError> + 'static>; + + #[inline] + fn inbound(self) -> Self::InboundSubstream { + InboundSubstream { inner: self.inner } + } + + #[inline] + fn outbound(self) -> Self::OutboundSubstream { + let mut inner = self.inner.lock(); + + // Assign a substream ID now. + let substream_id = { + let n = inner.next_outbound_stream_id; + inner.next_outbound_stream_id += 2; + n + }; + + // We use an RAII guard, so that we close the substream in case of an error. + struct OpenedSubstreamGuard(Arc>>, u32); + impl Drop for OpenedSubstreamGuard { + fn drop(&mut self) { clean_out_substream(&mut self.0.lock(), self.1); } + } + inner.opened_substreams.insert(substream_id); + let guard = OpenedSubstreamGuard(self.inner.clone(), substream_id); + + // We send `Open { substream_id }`, then flush, then only produce the substream. + let future = { + future::poll_fn({ + let inner = self.inner.clone(); + move || { + let elem = codec::Elem::Open { substream_id }; + poll_send(&mut inner.lock(), elem) + } + }).and_then({ + let inner = self.inner.clone(); + move |()| { + future::poll_fn(move || inner.lock().inner.poll_complete()) + } + }).map({ + let inner = self.inner.clone(); + move |()| { + mem::forget(guard); + Some(Substream { + inner: inner.clone(), + num: substream_id, + current_data: Bytes::new(), + endpoint: Endpoint::Dialer, + }) + } + }) + }; + + Box::new(future) as Box<_> + } +} + +/// Future to the next incoming substream. +pub struct InboundSubstream { + inner: Arc>>, +} + +impl Future for InboundSubstream +where C: AsyncRead + AsyncWrite +{ + type Item = Option>; + type Error = IoError; + + fn poll(&mut self) -> Poll { + let mut inner = self.inner.lock(); + + if inner.opened_substreams.len() >= MAX_SUBSTREAMS { + return Err(IoError::new(IoErrorKind::ConnectionRefused, + "exceeded maximum number of open substreams")); + } + + let num = try_ready!(next_match(&mut inner, |elem| { + match elem { + codec::Elem::Open { substream_id } => Some(*substream_id), // TODO: check even/uneven? + _ => None, + } + })); + + if let Some(num) = num { + inner.opened_substreams.insert(num); + Ok(Async::Ready(Some(Substream { + inner: self.inner.clone(), + current_data: Bytes::new(), + num, + endpoint: Endpoint::Listener, + }))) + } else { + Ok(Async::Ready(None)) + } + } +} + +/// Active substream to the remote. Implements `AsyncRead` and `AsyncWrite`. +pub struct Substream +where C: AsyncRead + AsyncWrite +{ + inner: Arc>>, + num: u32, + // Read buffer. Contains data read from `inner` but not yet dispatched by a call to `read()`. + current_data: Bytes, + endpoint: Endpoint, +} + +impl Read for Substream +where C: AsyncRead + AsyncWrite +{ + fn read(&mut self, buf: &mut [u8]) -> Result { + loop { + // First transfer from `current_data`. + if self.current_data.len() != 0 { + let len = cmp::min(self.current_data.len(), buf.len()); + buf[..len].copy_from_slice(&self.current_data.split_to(len)); + return Ok(len); + } + + let mut inner = self.inner.lock(); + let next_data_poll = next_match(&mut inner, |elem| { + match elem { + &codec::Elem::Data { ref substream_id, ref data, .. } if *substream_id == self.num => { // TODO: check endpoint? + Some(data.clone()) + }, + _ => None, + } + }); + + // We're in a loop, so all we need to do is set `self.current_data` to the data we + // just read and wait for the next iteration. + match next_data_poll { + Ok(Async::Ready(Some(data))) => self.current_data = data.freeze(), + Ok(Async::Ready(None)) => return Ok(0), + Ok(Async::NotReady) => return Err(IoErrorKind::WouldBlock.into()), + Err(err) => return Err(err), + } + } + } +} + +impl AsyncRead for Substream +where C: AsyncRead + AsyncWrite +{ +} + +impl Write for Substream +where C: AsyncRead + AsyncWrite +{ + fn write(&mut self, buf: &[u8]) -> Result { + let elem = codec::Elem::Data { + substream_id: self.num, + data: From::from(buf), + endpoint: self.endpoint, + }; + + let mut inner = self.inner.lock(); + match poll_send(&mut inner, elem) { + Ok(Async::Ready(())) => Ok(buf.len()), + Ok(Async::NotReady) => Err(IoErrorKind::WouldBlock.into()), + Err(err) => Err(err), + } + } + + fn flush(&mut self) -> Result<(), IoError> { + let mut inner = self.inner.lock(); + match inner.inner.poll_complete() { + Ok(Async::Ready(())) => Ok(()), + Ok(Async::NotReady) => Err(IoErrorKind::WouldBlock.into()), + Err(err) => Err(err), + } + } +} + +impl AsyncWrite for Substream +where C: AsyncRead + AsyncWrite +{ + fn shutdown(&mut self) -> Poll<(), IoError> { + let elem = codec::Elem::Close { + substream_id: self.num, + endpoint: self.endpoint, + }; + + let mut inner = self.inner.lock(); + poll_send(&mut inner, elem) + } +} + +impl Drop for Substream +where C: AsyncRead + AsyncWrite +{ + fn drop(&mut self) { + let _ = self.shutdown(); + clean_out_substream(&mut self.inner.lock(), self.num); } } diff --git a/mplex/src/read.rs b/mplex/src/read.rs deleted file mode 100644 index 47459f6f..00000000 --- a/mplex/src/read.rs +++ /dev/null @@ -1,542 +0,0 @@ -// Copyright 2017 Parity Technologies (UK) Ltd. -// -// Permission is hereby granted, free of charge, to any person obtaining a -// copy of this software and associated documentation files (the "Software"), -// to deal in the Software without restriction, including without limitation -// the rights to use, copy, modify, merge, publish, distribute, sublicense, -// and/or sell copies of the Software, and to permit persons to whom the -// Software is furnished to do so, subject to the following conditions: -// -// The above copyright notice and this permission notice shall be included in -// all copies or substantial portions of the Software. -// -// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS -// OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING -// FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER -// DEALINGS IN THE SOFTWARE. - -use circular_buffer::Array; -use futures::Async; -use futures::task; -use header::{MultiplexHeader, PacketType}; -use shared::SubstreamMetadata; -use std::io; -use tokio_io::AsyncRead; -use {bytes, varint}; - -pub enum NextMultiplexState { - NewStream(u32), - ParsingMessageBody(u32), - Ignore(u32), -} - -impl NextMultiplexState { - pub fn substream_id(&self) -> u32 { - match *self { - NextMultiplexState::NewStream(id) - | NextMultiplexState::ParsingMessageBody(id) - | NextMultiplexState::Ignore(id) => id, - } - } -} - -pub enum MultiplexReadState { - Header { - state: varint::DecoderState, - }, - BodyLength { - state: varint::DecoderState, - next: NextMultiplexState, - }, - NewStream { - substream_id: u32, - name: bytes::BytesMut, - remaining_bytes: usize, - }, - ParsingMessageBody { - substream_id: u32, - remaining_bytes: usize, - }, - Ignore { - substream_id: u32, - remaining_bytes: usize, - }, -} - -impl Default for MultiplexReadState { - fn default() -> Self { - MultiplexReadState::Header { - state: Default::default(), - } - } -} - -fn create_buffer(capacity: usize) -> bytes::BytesMut { - let mut buffer = bytes::BytesMut::with_capacity(capacity); - let zeroes = [0; 1024]; - let mut cap = capacity; - - while cap > 0 { - let len = cap.min(zeroes.len()); - buffer.extend_from_slice(&zeroes[..len]); - cap -= len; - } - - buffer -} - -fn block_on_wrong_stream>( - substream_id: u32, - remaining_bytes: usize, - lock: &mut ::shared::MultiplexShared, -) -> io::Result { - use std::{mem, slice}; - - lock.read_state = Some(MultiplexReadState::ParsingMessageBody { - substream_id, - remaining_bytes, - }); - - let mut out_consumed = 0; - let mut stream_eof = false; - if let Some((tasks, cache)) = lock.open_streams - .entry(substream_id) - .or_insert_with(|| SubstreamMetadata::new_open()) - .open_meta_mut() - .map(|cur| { - ( - mem::replace(&mut cur.read, Default::default()), - &mut cur.read_cache, - ) - }) { - // We check `cache.capacity()` since that can totally statically remove this branch in the - // `== 0` path. - if cache.capacity() > 0 && cache.len() < cache.capacity() { - let mut buf: Buf = unsafe { mem::uninitialized() }; - - // Can't fail because `cache.len() >= 0`, - // `cache.len() <= cache.capacity()` and - // `cache.capacity() == mem::size_of::()` - let buf_prefix = unsafe { - let max_that_fits_in_buffer = cache.capacity() - cache.len(); - // We know this won't panic because of the earlier - // `number_read >= buf.len()` check - let new_len = max_that_fits_in_buffer.min(remaining_bytes); - - slice::from_raw_parts_mut(buf.ptr_mut(), new_len) - }; - - match lock.stream.read(buf_prefix) { - Ok(consumed) => { - if consumed == 0 && !buf_prefix.is_empty() { - stream_eof = true - } - - let new_remaining = remaining_bytes - consumed; - - assert!(cache.extend_from_slice(&buf_prefix[..consumed])); - - out_consumed = consumed; - - lock.read_state = Some(MultiplexReadState::ParsingMessageBody { - substream_id, - remaining_bytes: new_remaining, - }); - } - Err(err) => { - if err.kind() != io::ErrorKind::WouldBlock { - for task in tasks { - task.notify(); - } - - return Err(err); - } - } - } - } - - for task in tasks { - task.notify(); - } - } - - if stream_eof { - lock.close() - } - - Ok(out_consumed) -} - -pub fn read_stream< - 'a, - Buf: Array, - O: Into>, - T: AsyncRead, ->( - lock: &mut ::shared::MultiplexShared, - stream_data: O, -) -> io::Result { - read_stream_internal(lock, stream_data.into()) -} - -fn read_stream_internal>( - lock: &mut ::shared::MultiplexShared, - mut stream_data: Option<(u32, &mut [u8])>, -) -> io::Result { - use self::MultiplexReadState::*; - - // This is only true if a stream exists and it has been closed in a "graceful" manner, so we - // can return `Ok(0)` like the `Read` trait requests. In any other case we want to return - // `WouldBlock` - let stream_has_been_gracefully_closed = stream_data - .as_ref() - .and_then(|&(id, _)| lock.open_streams.get(&id)) - .map(|meta| !meta.open()) - .unwrap_or(false); - - let mut on_block: io::Result = if stream_has_been_gracefully_closed { - Ok(0) - } else { - Err(io::ErrorKind::WouldBlock.into()) - }; - - if let Some((ref mut id, ref mut buf)) = stream_data { - if let Some(cur) = lock.open_streams - .entry(*id) - .or_insert_with(|| SubstreamMetadata::new_open()) - .open_meta_mut() - { - cur.read.push(task::current()); - - let cache = &mut cur.read_cache; - - if !cache.is_empty() { - let mut consumed = 0; - loop { - let cur_buf = &mut buf[consumed..]; - if cur_buf.is_empty() { - break; - } - - if let Some(out) = cache.pop_first_n_leaky(cur_buf.len()) { - cur_buf[..out.len()].copy_from_slice(out); - consumed += out.len(); - } else { - break; - }; - } - - on_block = Ok(consumed); - } - } - } - - loop { - match lock.read_state.take().unwrap_or_default() { - Header { - state: mut varint_state, - } => { - match varint_state.read(&mut lock.stream) { - Ok(Async::Ready(header)) => { - let header = if let Some(header) = header { - header - } else { - lock.close(); - return Ok(on_block.unwrap_or(0)); - }; - - let MultiplexHeader { - substream_id, - packet_type, - } = MultiplexHeader::parse(header).map_err(|err| { - debug!("failed to parse header: {}", err); - io::Error::new( - io::ErrorKind::Other, - format!("Error parsing header: {:?}", err), - ) - })?; - - match packet_type { - PacketType::Open => { - lock.read_state = Some(BodyLength { - state: Default::default(), - next: NextMultiplexState::NewStream(substream_id), - }) - } - PacketType::Message(_) => { - lock.read_state = Some(BodyLength { - state: Default::default(), - next: NextMultiplexState::ParsingMessageBody(substream_id), - }) - } - // NOTE: What's the difference between close and reset? - PacketType::Close(_) | PacketType::Reset(_) => { - lock.read_state = Some(BodyLength { - state: Default::default(), - next: NextMultiplexState::Ignore(substream_id), - }); - - lock.close_stream(substream_id); - } - } - } - Ok(Async::NotReady) => { - lock.read_state = Some(Header { - state: varint_state, - }); - return on_block; - } - Err(error) => { - return if let varint::Error(varint::ErrorKind::Io(inner), ..) = error { - debug!("failed to read header: {}", inner); - Err(inner) - } else { - debug!("failed to read header: {}", error); - Err(io::Error::new(io::ErrorKind::Other, error.description())) - }; - } - } - } - BodyLength { - state: mut varint_state, - next, - } => { - use self::NextMultiplexState::*; - - let body_len = varint_state.read(&mut lock.stream).map_err(|e| { - debug!("substream {}: failed to read body length: {}", next.substream_id(), e); - io::Error::new(io::ErrorKind::Other, "Error reading varint") - })?; - - match body_len { - Async::Ready(length) => { - // TODO: Limit `length` to prevent resource-exhaustion DOS - let length = if let Some(length) = length { - length - } else { - lock.close(); - return Ok(on_block.unwrap_or(0)); - }; - - lock.read_state = match next { - Ignore(substream_id) => Some(MultiplexReadState::Ignore { - substream_id, - remaining_bytes: length, - }), - NewStream(substream_id) => { - if length == 0 { - lock.to_open.insert(substream_id, None); - - None - } else { - Some(MultiplexReadState::NewStream { - // TODO: Uninit buffer - name: create_buffer(length), - remaining_bytes: length, - substream_id, - }) - } - } - ParsingMessageBody(substream_id) => { - let is_open = lock.open_streams - .get(&substream_id) - .map(SubstreamMetadata::open) - .unwrap_or_else(|| lock.to_open.contains_key(&substream_id)); - - if is_open { - Some(MultiplexReadState::ParsingMessageBody { - remaining_bytes: length, - substream_id, - }) - } else { - Some(MultiplexReadState::Ignore { - substream_id, - remaining_bytes: length, - }) - } - } - }; - } - Async::NotReady => { - lock.read_state = Some(BodyLength { - state: varint_state, - next, - }); - - return on_block; - } - } - } - NewStream { - substream_id, - mut name, - remaining_bytes, - } => { - if remaining_bytes == 0 { - lock.to_open.insert(substream_id, Some(name.freeze())); - - lock.read_state = None; - } else { - let cursor_pos = name.len() - remaining_bytes; - let consumed = lock.stream.read(&mut name[cursor_pos..]); - - match consumed { - Ok(consumed) => { - if consumed == 0 { - lock.close() - } - - let new_remaining = remaining_bytes - consumed; - - lock.read_state = Some(NewStream { - substream_id, - name, - remaining_bytes: new_remaining, - }); - } - Err(ref err) if err.kind() == io::ErrorKind::WouldBlock => { - lock.read_state = Some(NewStream { - substream_id, - name, - remaining_bytes, - }); - - return on_block; - } - Err(other) => { - debug!("substream {}: failed to read new stream: {}", - substream_id, - other); - lock.read_state = Some(NewStream { - substream_id, - name, - remaining_bytes, - }); - return Err(other); - } - } - } - } - ParsingMessageBody { - substream_id, - remaining_bytes, - } => { - if let Some((ref mut id, ref mut buf)) = stream_data { - use MultiplexReadState::*; - - let number_read = *on_block.as_ref().unwrap_or(&0); - - if remaining_bytes == 0 { - lock.read_state = None; - } else if substream_id == *id { - if number_read >= buf.len() { - lock.read_state = Some(ParsingMessageBody { - substream_id, - remaining_bytes, - }); - - return Ok(number_read); - } - - let read_result = { - // We know this won't panic because of the earlier - // `number_read >= buf.len()` check - let new_len = (buf.len() - number_read).min(remaining_bytes); - let slice = &mut buf[number_read..number_read + new_len]; - - lock.stream.read(slice) - }; - - lock.read_state = Some(ParsingMessageBody { - substream_id, - remaining_bytes, - }); - - match read_result { - Ok(consumed) => { - if consumed == 0 { - lock.close() - } - - let new_remaining = remaining_bytes - consumed; - - lock.read_state = Some(ParsingMessageBody { - substream_id, - remaining_bytes: new_remaining, - }); - - on_block = Ok(number_read + consumed); - } - Err(ref err) if err.kind() == io::ErrorKind::WouldBlock => { - return on_block; - } - Err(other) => { - debug!("substream {}: failed to read message body: {}", - substream_id, - other); - return Err(other); - } - } - } else { - // We cannot make progress here, another stream has to accept this packet - if block_on_wrong_stream(substream_id, remaining_bytes, lock)? == 0 { - return on_block; - } - } - } else { - // We cannot make progress here, another stream has to accept this packet - if block_on_wrong_stream(substream_id, remaining_bytes, lock)? == 0 { - return on_block; - } - } - } - Ignore { - substream_id, - mut remaining_bytes, - } => { - let mut ignore_buf: [u8; 256] = [0; 256]; - - loop { - if remaining_bytes == 0 { - lock.read_state = None; - break; - } else { - let new_len = ignore_buf.len().min(remaining_bytes); - match lock.stream.read(&mut ignore_buf[..new_len]) { - Ok(consumed) => { - if consumed == 0 { - lock.close() - } - remaining_bytes -= consumed; - lock.read_state = Some(Ignore { - substream_id, - remaining_bytes: remaining_bytes, - }); - } - Err(ref err) if err.kind() == io::ErrorKind::WouldBlock => { - lock.read_state = Some(Ignore { - substream_id, - remaining_bytes, - }); - return on_block; - } - Err(other) => { - debug!("substream {}: failed to read ignore bytes: {}", - substream_id, - other); - lock.read_state = Some(Ignore { - substream_id, - remaining_bytes, - }); - return Err(other); - } - } - } - } - } - } - } -} diff --git a/mplex/src/shared.rs b/mplex/src/shared.rs deleted file mode 100644 index a7d725d5..00000000 --- a/mplex/src/shared.rs +++ /dev/null @@ -1,124 +0,0 @@ -// Copyright 2017 Parity Technologies (UK) Ltd. -// -// Permission is hereby granted, free of charge, to any person obtaining a -// copy of this software and associated documentation files (the "Software"), -// to deal in the Software without restriction, including without limitation -// the rights to use, copy, modify, merge, publish, distribute, sublicense, -// and/or sell copies of the Software, and to permit persons to whom the -// Software is furnished to do so, subject to the following conditions: -// -// The above copyright notice and this permission notice shall be included in -// all copies or substantial portions of the Software. -// -// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS -// OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING -// FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER -// DEALINGS IN THE SOFTWARE. - -use read::MultiplexReadState; -use write::MultiplexWriteState; - -use arrayvec::ArrayVec; -use bytes::Bytes; -use circular_buffer::{Array, CircularBuffer}; -use futures::task::Task; -use std::collections::HashMap; - -const BUF_SIZE: usize = 1024; - -pub type ByteBuf = ArrayVec<[u8; BUF_SIZE]>; - -pub enum SubstreamMetadata { - Closed, - Open(OpenSubstreamMetadata), -} - -pub struct OpenSubstreamMetadata { - pub read_cache: CircularBuffer, - pub read: Vec, - pub write: Vec, -} - -impl SubstreamMetadata { - pub fn new_open() -> Self { - SubstreamMetadata::Open(OpenSubstreamMetadata { - read_cache: Default::default(), - read: Default::default(), - write: Default::default(), - }) - } - - pub fn open(&self) -> bool { - match *self { - SubstreamMetadata::Closed => false, - SubstreamMetadata::Open { .. } => true, - } - } - - pub fn open_meta_mut(&mut self) -> Option<&mut OpenSubstreamMetadata> { - match *self { - SubstreamMetadata::Closed => None, - SubstreamMetadata::Open(ref mut meta) => Some(meta), - } - } -} - -// TODO: Split reading and writing into different structs and have information shared between the -// two in a `RwLock`, since `open_streams` and `to_open` are mostly read-only. -pub struct MultiplexShared { - // We use `Option` in order to take ownership of heap allocations within `DecoderState` and - // `BytesMut`. If this is ever observably `None` then something has panicked or the underlying - // stream returned an error. - pub read_state: Option, - pub write_state: Option, - pub stream: T, - eof: bool, // true, if `stream` has been exhausted - pub open_streams: HashMap>, - pub meta_write_tasks: Vec, - // TODO: Should we use a version of this with a fixed size that doesn't allocate and return - // `WouldBlock` if it's full? Even if we ignore or size-cap names you can still open 2^32 - // streams. - pub to_open: HashMap>, -} - -impl MultiplexShared { - pub fn new(stream: T) -> Self { - MultiplexShared { - read_state: Default::default(), - write_state: Default::default(), - open_streams: Default::default(), - meta_write_tasks: Default::default(), - to_open: Default::default(), - stream: stream, - eof: false, - } - } - - pub fn open_stream(&mut self, id: u32) -> bool { - trace!("open stream {}", id); - self.open_streams - .entry(id) - .or_insert(SubstreamMetadata::new_open()) - .open() - } - - pub fn close_stream(&mut self, id: u32) { - trace!("close stream {}", id); - self.open_streams.insert(id, SubstreamMetadata::Closed); - } - - pub fn close(&mut self) { - self.eof = true - } - - pub fn is_closed(&self) -> bool { - self.eof - } -} - -pub fn buf_from_slice(slice: &[u8]) -> ByteBuf { - slice.iter().cloned().take(BUF_SIZE).collect() -} diff --git a/mplex/src/write.rs b/mplex/src/write.rs deleted file mode 100644 index a33b1e90..00000000 --- a/mplex/src/write.rs +++ /dev/null @@ -1,227 +0,0 @@ -// Copyright 2017 Parity Technologies (UK) Ltd. -// -// Permission is hereby granted, free of charge, to any person obtaining a -// copy of this software and associated documentation files (the "Software"), -// to deal in the Software without restriction, including without limitation -// the rights to use, copy, modify, merge, publish, distribute, sublicense, -// and/or sell copies of the Software, and to permit persons to whom the -// Software is furnished to do so, subject to the following conditions: -// -// The above copyright notice and this permission notice shall be included in -// all copies or substantial portions of the Software. -// -// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS -// OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING -// FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER -// DEALINGS IN THE SOFTWARE. - -use header::MultiplexHeader; -use shared::{ByteBuf, MultiplexShared, SubstreamMetadata}; - -use circular_buffer; -use futures::task; -use std::io; -use tokio_io::AsyncWrite; -use varint; - -#[derive(Copy, Clone, PartialEq, Eq, Debug)] -pub enum RequestType { - Meta, - Substream, -} - -#[derive(Copy, Clone, PartialEq, Eq, Debug)] -pub struct WriteRequest { - header: MultiplexHeader, - request_type: RequestType, -} - -impl WriteRequest { - pub fn substream(header: MultiplexHeader) -> Self { - WriteRequest { - header, - request_type: RequestType::Substream, - } - } - - pub fn meta(header: MultiplexHeader) -> Self { - WriteRequest { - header, - request_type: RequestType::Meta, - } - } -} - -#[derive(Default, Debug)] -pub struct MultiplexWriteState { - current: Option<(WriteRequest, MultiplexWriteStateInner)>, - queued: Option, - // TODO: Actually close these - to_close: Vec, -} - -#[derive(Debug)] -pub enum MultiplexWriteStateInner { - WriteHeader { state: varint::EncoderState }, - BodyLength { state: varint::EncoderState }, - Body { size: usize }, -} - -pub fn write_stream( - lock: &mut MultiplexShared, - write_request: WriteRequest, - buf: &mut io::Cursor, -) -> io::Result { - use futures::Async; - use num_traits::cast::ToPrimitive; - use varint::WriteState; - use write::MultiplexWriteStateInner::*; - - let mut on_block = Err(io::ErrorKind::WouldBlock.into()); - let mut write_state = lock.write_state.take().unwrap_or_default(); - let (request, mut state) = write_state.current.take().unwrap_or_else(|| { - ( - write_request, - MultiplexWriteStateInner::WriteHeader { - state: varint::EncoderState::new(write_request.header.to_u64()), - }, - ) - }); - - let id = write_request.header.substream_id; - - if buf.get_ref().len() as u64 - buf.position() == 0 { - return Ok(0); - } - - match (request.request_type, write_request.request_type) { - (RequestType::Substream, RequestType::Substream) if request.header.substream_id != id => { - use std::mem; - - if let Some(cur) = lock.open_streams - .entry(id) - .or_insert_with(|| SubstreamMetadata::new_open()) - .open_meta_mut() - { - cur.write.push(task::current()); - } - - if let Some(tasks) = lock.open_streams - .get_mut(&request.header.substream_id) - .and_then(SubstreamMetadata::open_meta_mut) - .map(|cur| mem::replace(&mut cur.write, Default::default())) - { - for task in tasks { - task.notify(); - } - } - - lock.write_state = Some(write_state); - return on_block; - } - (RequestType::Substream, RequestType::Meta) => { - use std::mem; - - lock.write_state = Some(write_state); - lock.meta_write_tasks.push(task::current()); - - if let Some(tasks) = lock.open_streams - .get_mut(&request.header.substream_id) - .and_then(SubstreamMetadata::open_meta_mut) - .map(|cur| mem::replace(&mut cur.write, Default::default())) - { - for task in tasks { - task.notify(); - } - } - - return on_block; - } - (RequestType::Meta, RequestType::Substream) => { - use std::mem; - - lock.write_state = Some(write_state); - - if let Some(cur) = lock.open_streams - .entry(id) - .or_insert_with(|| SubstreamMetadata::new_open()) - .open_meta_mut() - { - cur.write.push(task::current()); - } - - for task in mem::replace(&mut lock.meta_write_tasks, Default::default()) { - task.notify(); - } - - return on_block; - } - _ => {} - } - - loop { - // Err = should return, Ok = continue - let new_state = match state { - WriteHeader { - state: mut inner_state, - } => match inner_state - .write(&mut lock.stream) - .map_err(|_| io::ErrorKind::Other)? - { - Async::Ready(WriteState::Done(_)) => Ok(BodyLength { - state: varint::EncoderState::new(buf.get_ref().len()), - }), - Async::Ready(WriteState::Pending(_)) | Async::NotReady => { - Err(Some(WriteHeader { state: inner_state })) - } - }, - BodyLength { - state: mut inner_state, - } => match inner_state - .write(&mut lock.stream) - .map_err(|_| io::ErrorKind::Other)? - { - Async::Ready(WriteState::Done(_)) => Ok(Body { - size: inner_state.source().to_usize().unwrap_or(::std::usize::MAX), - }), - Async::Ready(WriteState::Pending(_)) => Ok(BodyLength { state: inner_state }), - Async::NotReady => Err(Some(BodyLength { state: inner_state })), - }, - Body { size } => { - if buf.position() == buf.get_ref().len() as u64 { - Err(None) - } else { - match lock.stream.write(&buf.get_ref()[buf.position() as usize..]) { - Ok(just_written) => { - let cur_pos = buf.position(); - buf.set_position(cur_pos + just_written as u64); - on_block = Ok(on_block.unwrap_or(0) + just_written); - Ok(Body { - size: size - just_written, - }) - } - Err(ref err) if err.kind() == io::ErrorKind::WouldBlock => { - Err(Some(Body { size })) - } - Err(other) => { - debug!("substream {}: failed to write body: {}", id, other); - return Err(other); - } - } - } - } - }; - - match new_state { - Ok(new_state) => state = new_state, - Err(new_state) => { - write_state.current = new_state.map(|state| (request, state)); - lock.write_state = Some(write_state); - return on_block; - } - } - } -} diff --git a/mplex/tests/two_peers.rs b/mplex/tests/two_peers.rs index a455c62f..95a21f0c 100644 --- a/mplex/tests/two_peers.rs +++ b/mplex/tests/two_peers.rs @@ -42,7 +42,7 @@ fn client_to_server_outbound() { let bg_thread = thread::spawn(move || { let transport = - TcpConfig::new().with_upgrade(multiplex::MultiplexConfig::new()); + TcpConfig::new().with_upgrade(multiplex::MplexConfig::new()); let (listener, addr) = transport .listen_on("/ip4/127.0.0.1/tcp/0".parse().unwrap()) @@ -70,7 +70,7 @@ fn client_to_server_outbound() { tokio_current_thread::block_on_all(future).unwrap(); }); - let transport = TcpConfig::new().with_upgrade(multiplex::MultiplexConfig::new()); + let transport = TcpConfig::new().with_upgrade(multiplex::MplexConfig::new()); let future = transport .dial(rx.recv().unwrap()) @@ -92,7 +92,7 @@ fn client_to_server_inbound() { let bg_thread = thread::spawn(move || { let transport = - TcpConfig::new().with_upgrade(multiplex::MultiplexConfig::new()); + TcpConfig::new().with_upgrade(multiplex::MplexConfig::new()); let (listener, addr) = transport .listen_on("/ip4/127.0.0.1/tcp/0".parse().unwrap()) @@ -120,7 +120,7 @@ fn client_to_server_inbound() { tokio_current_thread::block_on_all(future).unwrap(); }); - let transport = TcpConfig::new().with_upgrade(multiplex::MultiplexConfig::new()); + let transport = TcpConfig::new().with_upgrade(multiplex::MplexConfig::new()); let future = transport .dial(rx.recv().unwrap())