diff --git a/core/src/upgrade/traits.rs b/core/src/upgrade/traits.rs index 774a7d5f..51999246 100644 --- a/core/src/upgrade/traits.rs +++ b/core/src/upgrade/traits.rs @@ -20,10 +20,10 @@ use bytes::Bytes; use futures::future::Future; -use std::io::Error as IoError; +use std::{io::Error as IoError, ops::Not}; /// Type of connection for the upgrade. -#[derive(Debug, Copy, Clone, PartialEq, Eq)] +#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)] pub enum Endpoint { /// The socket comes from a dialer. Dialer, @@ -31,6 +31,18 @@ pub enum Endpoint { Listener, } +impl Not for Endpoint { + type Output = Endpoint; + + fn not(self) -> Self::Output { + match self { + Endpoint::Dialer => Endpoint::Listener, + Endpoint::Listener => Endpoint::Dialer + } + } +} + + /// Implemented on structs that describe a possible upgrade to a connection between two peers. /// /// The generic `C` is the type of the incoming connection before it is upgraded. diff --git a/muxers/mplex/src/codec.rs b/muxers/mplex/src/codec.rs index e7106a58..b645a0fe 100644 --- a/muxers/mplex/src/codec.rs +++ b/muxers/mplex/src/codec.rs @@ -49,6 +49,15 @@ impl Elem { } } + pub fn endpoint(&self) -> Option { + match *self { + Elem::Open { .. } => None, + Elem::Data { endpoint, .. } => Some(endpoint), + Elem::Close { endpoint, .. } => Some(endpoint), + Elem::Reset { endpoint, .. } => Some(endpoint) + } + } + /// Returns true if this message is `Close` or `Reset`. #[inline] pub fn is_close_or_reset_msg(&self) -> bool { diff --git a/muxers/mplex/src/lib.rs b/muxers/mplex/src/lib.rs index 709ed007..b6a633ff 100644 --- a/muxers/mplex/src/lib.rs +++ b/muxers/mplex/src/lib.rs @@ -171,7 +171,9 @@ struct MultiplexInner { buffer: Vec, // List of Ids of opened substreams. Used to filter out messages that don't belong to any // substream. Note that this is handled exclusively by `next_match`. - opened_substreams: FnvHashSet, + // The `Endpoint` value denotes who initiated the substream from our point of view + // (see note [StreamId]). + opened_substreams: FnvHashSet<(u32, Endpoint)>, // Id of the next outgoing substream. Should always increase by two. next_outbound_stream_id: u32, /// List of tasks to notify when a read event happens on the underlying stream. @@ -200,6 +202,16 @@ task_local!{ static TASK_ID: usize = NEXT_TASK_ID.fetch_add(1, Ordering::Relaxed) } +// Note [StreamId]: mplex no longer partitions stream IDs into odd (for initiators) and +// even ones (for receivers). Streams are instead identified by a number and whether the flag +// is odd (for receivers) or even (for initiators). `Open` frames do not have a flag, but are +// sent unidirectional. As a consequence, we need to remember if the stream was initiated by us +// or remotely and we store the information from our point of view, i.e. receiving an `Open` frame +// is stored as `(, Listener)`, sending an `Open` frame as `(, Dialer)`. Receiving +// a `Data` frame with flag `MessageReceiver` (= 1) means that we initiated the stream, so the +// entry has been stored as `(, Dialer)`. So, when looking up streams based on frames +// received, we have to invert the `Endpoint`, except for `Open`. + /// Processes elements in `inner` until one matching `filter` is found. /// /// If `NotReady` is returned, the current task is scheduled for later, just like with any `Poll`. @@ -259,25 +271,21 @@ where C: AsyncRead + AsyncWrite, // Handle substreams opening/closing. match elem { codec::Elem::Open { substream_id } => { - if (substream_id % 2) == (inner.next_outbound_stream_id % 2) { - inner.error = Err(IoError::new(IoErrorKind::Other, "invalid substream id opened")); - return Err(IoError::new(IoErrorKind::Other, "invalid substream id opened")); - } - - if !inner.opened_substreams.insert(substream_id) { + if !inner.opened_substreams.insert((substream_id, Endpoint::Listener)) { debug!("Received open message for substream {} which was already open", substream_id) } - }, - codec::Elem::Close { substream_id, .. } | codec::Elem::Reset { substream_id, .. } => { - inner.opened_substreams.remove(&substream_id); - }, + } + codec::Elem::Close { substream_id, endpoint, .. } | codec::Elem::Reset { substream_id, endpoint, .. } => { + inner.opened_substreams.remove(&(substream_id, !endpoint)); + } _ => () } if let Some(out) = filter(&elem) { return Ok(Async::Ready(Some(out))); } else { - if inner.opened_substreams.contains(&elem.substream_id()) || elem.is_open_msg() { + let endpoint = elem.endpoint().unwrap_or(Endpoint::Dialer); + if inner.opened_substreams.contains(&(elem.substream_id(), !endpoint)) || elem.is_open_msg() { inner.buffer.push(elem); } else if !elem.is_close_or_reset_msg() { debug!("Ignored message {:?} because the substream wasn't open", elem); @@ -346,7 +354,7 @@ where C: AsyncRead + AsyncWrite n }; - inner.opened_substreams.insert(substream_id); + inner.opened_substreams.insert((substream_id, Endpoint::Dialer)); OutboundSubstream { num: substream_id, @@ -423,10 +431,12 @@ where C: AsyncRead + AsyncWrite let mut inner = self.inner.lock(); let next_data_poll = next_match(&mut inner, |elem| { match elem { - &codec::Elem::Data { ref substream_id, ref data, .. } if *substream_id == substream.num => { // TODO: check endpoint? + codec::Elem::Data { substream_id, endpoint, data, .. } + if *substream_id == substream.num && *endpoint != substream.endpoint => // see note [StreamId] + { Some(data.clone()) - }, - _ => None, + } + _ => None } }); @@ -438,7 +448,7 @@ where C: AsyncRead + AsyncWrite Ok(Async::NotReady) => { // There was no data packet in the buffer about this substream ; maybe it's // because it has been closed. - if inner.opened_substreams.contains(&substream.num) { + if inner.opened_substreams.contains(&(substream.num, substream.endpoint)) { return Err(IoErrorKind::WouldBlock.into()); } else { return Ok(0);