diff --git a/multiplex-rs/src/lib.rs b/multiplex-rs/src/lib.rs index 1d982390..8c47bf26 100644 --- a/multiplex-rs/src/lib.rs +++ b/multiplex-rs/src/lib.rs @@ -356,7 +356,7 @@ where mod tests { use super::*; use std::io; - use tokio_io::io as tokio; + use tokio_io::io as tokio; #[test] fn can_use_one_stream() { @@ -414,7 +414,11 @@ mod tests { outbound.sort_by_key(|a| a.id()); for (i, substream) in outbound.iter_mut().enumerate() { - assert!(tokio::write_all(substream, i.to_string().as_bytes()).wait().is_ok()); + assert!( + tokio::write_all(substream, i.to_string().as_bytes()) + .wait() + .is_ok() + ); } let stream = io::Cursor::new(mplex.state.lock().stream.get_ref().clone()); @@ -538,7 +542,13 @@ mod tests { assert_eq!(substream.id(), 0); assert_eq!(substream.name(), None); - assert_eq!(tokio::read(&mut substream, &mut [0; 100][..]).wait().unwrap().2, 0); + assert_eq!( + tokio::read(&mut substream, &mut [0; 100][..]) + .wait() + .unwrap() + .2, + 0 + ); } #[test] @@ -586,7 +596,10 @@ mod tests { for _ in 0..20 { let mut buf = [0; 1]; - assert_eq!(tokio::read(&mut substream, &mut buf[..]).wait().unwrap().2, 1); + assert_eq!( + tokio::read(&mut substream, &mut buf[..]).wait().unwrap().2, + 1 + ); out.push(buf[0]); } diff --git a/multiplex-rs/src/read.rs b/multiplex-rs/src/read.rs index 52743a56..5df00b14 100644 --- a/multiplex-rs/src/read.rs +++ b/multiplex-rs/src/read.rs @@ -81,6 +81,7 @@ pub fn read_stream<'a, O: Into>, T: AsyncRead>( stream_data: O, ) -> io::Result { use self::MultiplexReadState::*; + use std::mem; let mut stream_data = stream_data.into(); let stream_has_been_gracefully_closed = stream_data @@ -95,19 +96,14 @@ pub fn read_stream<'a, O: Into>, T: AsyncRead>( Err(io::ErrorKind::WouldBlock.into()) }; - if let Some((ref mut id, ..)) = stream_data { - // TODO: We can do this only hashing `id` once using the entry API. - let write = lock.open_streams - .get(id) - .and_then(SubstreamMetadata::write_task) - .cloned(); - lock.open_streams.insert( - *id, - SubstreamMetadata::Open { - read: Some(task::current()), - write, - }, - ); + if let Some((ref id, ..)) = stream_data { + if let Some(cur) = lock.open_streams + .entry(*id) + .or_insert_with(|| SubstreamMetadata::new_open()) + .read_tasks_mut() + { + cur.push(task::current()); + } } loop { @@ -349,11 +345,14 @@ pub fn read_stream<'a, O: Into>, T: AsyncRead>( remaining_bytes, }); - if let Some(task) = lock.open_streams - .get(&substream_id) - .and_then(SubstreamMetadata::read_task) + if let Some(tasks) = lock.open_streams + .get_mut(&substream_id) + .and_then(SubstreamMetadata::read_tasks_mut) + .map(|cur| mem::replace(cur, Default::default())) { - task.notify(); + for task in tasks { + task.notify(); + } } // We cannot make progress here, another stream has to accept this packet @@ -365,6 +364,16 @@ pub fn read_stream<'a, O: Into>, T: AsyncRead>( remaining_bytes, }); + if let Some(tasks) = lock.open_streams + .get_mut(&substream_id) + .and_then(SubstreamMetadata::read_tasks_mut) + .map(|cur| mem::replace(cur, Default::default())) + { + for task in tasks { + task.notify(); + } + } + // We cannot make progress here, a stream has to accept this packet return on_block; } diff --git a/multiplex-rs/src/shared.rs b/multiplex-rs/src/shared.rs index 172bfd34..9d4fce22 100644 --- a/multiplex-rs/src/shared.rs +++ b/multiplex-rs/src/shared.rs @@ -32,13 +32,17 @@ pub type ByteBuf = ArrayVec<[u8; BUF_SIZE]>; pub enum SubstreamMetadata { Closed, - Open { - read: Option, - write: Option, - }, + Open { read: Vec, write: Vec }, } impl SubstreamMetadata { + pub fn new_open() -> Self { + SubstreamMetadata::Open { + read: Default::default(), + write: Default::default(), + } + } + pub fn open(&self) -> bool { match *self { SubstreamMetadata::Closed => false, @@ -46,17 +50,17 @@ impl SubstreamMetadata { } } - pub fn read_task(&self) -> Option<&Task> { + pub fn read_tasks_mut(&mut self) -> Option<&mut Vec> { match *self { SubstreamMetadata::Closed => None, - SubstreamMetadata::Open { ref read, .. } => read.as_ref(), + SubstreamMetadata::Open { ref mut read, .. } => Some(read), } } - pub fn write_task(&self) -> Option<&Task> { + pub fn write_tasks_mut(&mut self) -> Option<&mut Vec> { match *self { SubstreamMetadata::Closed => None, - SubstreamMetadata::Open { ref write, .. } => write.as_ref(), + SubstreamMetadata::Open { ref mut write, .. } => Some(write), } } } @@ -72,8 +76,10 @@ pub struct MultiplexShared { pub stream: T, // true if the stream is open, false otherwise pub open_streams: HashMap, + pub meta_write_tasks: Vec, // TODO: Should we use a version of this with a fixed size that doesn't allocate and return - // `WouldBlock` if it's full? + // `WouldBlock` if it's full? Even if we ignore or size-cap names you can still open 2^32 + // streams. pub to_open: HashMap>, } @@ -83,6 +89,7 @@ impl MultiplexShared { read_state: Default::default(), write_state: Default::default(), open_streams: Default::default(), + meta_write_tasks: Default::default(), to_open: Default::default(), stream: stream, } @@ -92,8 +99,8 @@ impl MultiplexShared { self.open_streams .entry(id) .or_insert(SubstreamMetadata::Open { - read: None, - write: None, + read: Default::default(), + write: Default::default(), }) .open() } diff --git a/multiplex-rs/src/write.rs b/multiplex-rs/src/write.rs index 3b702a67..ab5c151f 100644 --- a/multiplex-rs/src/write.rs +++ b/multiplex-rs/src/write.rs @@ -94,31 +94,64 @@ pub fn write_stream( match (request.request_type, write_request.request_type) { (RequestType::Substream, RequestType::Substream) if request.header.substream_id != id => { - let read = lock.open_streams - .get(&id) - .and_then(SubstreamMetadata::read_task) - .cloned(); + use std::mem; - if let Some(task) = lock.open_streams - .get(&request.header.substream_id) - .and_then(SubstreamMetadata::write_task) + if let Some(cur) = lock.open_streams + .entry(id) + .or_insert_with(|| SubstreamMetadata::new_open()) + .write_tasks_mut() { - task.notify(); + cur.push(task::current()); + } + + if let Some(tasks) = lock.open_streams + .get_mut(&request.header.substream_id) + .and_then(SubstreamMetadata::write_tasks_mut) + .map(|cur| mem::replace(cur, Default::default())) + { + for task in tasks { + task.notify(); + } } - lock.open_streams.insert( - id, - SubstreamMetadata::Open { - write: Some(task::current()), - read, - }, - ); lock.write_state = Some(write_state); return on_block; } - (RequestType::Substream, RequestType::Meta) - | (RequestType::Meta, RequestType::Substream) => { + (RequestType::Substream, RequestType::Meta) => { + use std::mem; + lock.write_state = Some(write_state); + lock.meta_write_tasks.push(task::current()); + + if let Some(tasks) = lock.open_streams + .get_mut(&request.header.substream_id) + .and_then(SubstreamMetadata::write_tasks_mut) + .map(|cur| mem::replace(cur, Default::default())) + { + for task in tasks { + task.notify(); + } + } + + return on_block; + } + (RequestType::Meta, RequestType::Substream) => { + use std::mem; + + lock.write_state = Some(write_state); + + if let Some(cur) = lock.open_streams + .entry(id) + .or_insert_with(|| SubstreamMetadata::new_open()) + .write_tasks_mut() + { + cur.push(task::current()); + } + + for task in mem::replace(&mut lock.meta_write_tasks, Default::default()) { + task.notify(); + } + return on_block; } _ => {}