From 624d2238f9e3b7d277acc4708f71a0b7cb9a3bf6 Mon Sep 17 00:00:00 2001 From: Vurich Date: Mon, 11 Dec 2017 17:57:11 +0100 Subject: [PATCH] Set task correctly when not blocked --- multiplex-rs/src/lib.rs | 41 +++++++++++++++++++--------------------- multiplex-rs/src/read.rs | 26 ++++++++++++++----------- 2 files changed, 34 insertions(+), 33 deletions(-) diff --git a/multiplex-rs/src/lib.rs b/multiplex-rs/src/lib.rs index 2393ddee..1d982390 100644 --- a/multiplex-rs/src/lib.rs +++ b/multiplex-rs/src/lib.rs @@ -356,6 +356,7 @@ where mod tests { use super::*; use std::io; + use tokio_io::io as tokio; #[test] fn can_use_one_stream() { @@ -367,7 +368,7 @@ mod tests { let mut substream = mplex.clone().outbound().wait().unwrap(); - assert!(substream.write(message).is_ok()); + assert!(tokio::write_all(&mut substream, message).wait().is_ok()); let id = substream.id(); @@ -394,7 +395,7 @@ mod tests { let mut buf = vec![0; message.len()]; - assert!(substream.read(&mut buf).is_ok()); + assert!(tokio::read(&mut substream, &mut buf).wait().is_ok()); assert_eq!(&buf, message); } @@ -404,35 +405,31 @@ mod tests { let mplex = Multiplex::dial(stream); - let mut outbound: Vec> = vec![ - mplex.clone().outbound().wait().unwrap(), - mplex.clone().outbound().wait().unwrap(), - mplex.clone().outbound().wait().unwrap(), - mplex.clone().outbound().wait().unwrap(), - mplex.clone().outbound().wait().unwrap(), - ]; + let mut outbound: Vec> = vec![]; + + for _ in 0..5 { + outbound.push(mplex.clone().outbound().wait().unwrap()); + } outbound.sort_by_key(|a| a.id()); for (i, substream) in outbound.iter_mut().enumerate() { - assert!(substream.write(i.to_string().as_bytes()).is_ok()); + assert!(tokio::write_all(substream, i.to_string().as_bytes()).wait().is_ok()); } let stream = io::Cursor::new(mplex.state.lock().stream.get_ref().clone()); let mplex = Multiplex::listen(stream); - let mut inbound: Vec> = vec![ - mplex.clone().inbound().wait().unwrap(), - mplex.clone().inbound().wait().unwrap(), - mplex.clone().inbound().wait().unwrap(), - mplex.clone().inbound().wait().unwrap(), - mplex.clone().inbound().wait().unwrap(), - ]; + let mut inbound: Vec> = vec![]; + + for _ in 0..5 { + inbound.push(mplex.clone().inbound().wait().unwrap()); + } inbound.sort_by_key(|a| a.id()); - for (substream, outbound) in inbound.iter_mut().zip(outbound.iter()) { + for (mut substream, outbound) in inbound.iter_mut().zip(outbound.iter()) { let id = outbound.id(); assert_eq!(id, substream.id()); assert_eq!( @@ -443,7 +440,7 @@ mod tests { ); let mut buf = [0; 3]; - assert_eq!(substream.read(&mut buf).unwrap(), 1); + assert_eq!(tokio::read(&mut substream, &mut buf).wait().unwrap().2, 1); } } @@ -494,7 +491,7 @@ mod tests { let mut buf = vec![0; message.len()]; - assert!(substream.read(&mut buf).is_ok()); + assert!(tokio::read(&mut substream, &mut buf).wait().is_ok()); assert_eq!(&buf, message); } @@ -541,7 +538,7 @@ mod tests { assert_eq!(substream.id(), 0); assert_eq!(substream.name(), None); - assert_eq!(substream.read(&mut [0; 100][..]).unwrap(), 0); + assert_eq!(tokio::read(&mut substream, &mut [0; 100][..]).wait().unwrap().2, 0); } #[test] @@ -589,7 +586,7 @@ mod tests { for _ in 0..20 { let mut buf = [0; 1]; - assert_eq!(substream.read(&mut buf[..]).unwrap(), 1); + assert_eq!(tokio::read(&mut substream, &mut buf[..]).wait().unwrap().2, 1); out.push(buf[0]); } diff --git a/multiplex-rs/src/read.rs b/multiplex-rs/src/read.rs index 5b67649b..52743a56 100644 --- a/multiplex-rs/src/read.rs +++ b/multiplex-rs/src/read.rs @@ -95,6 +95,21 @@ pub fn read_stream<'a, O: Into>, T: AsyncRead>( Err(io::ErrorKind::WouldBlock.into()) }; + if let Some((ref mut id, ..)) = stream_data { + // TODO: We can do this only hashing `id` once using the entry API. + let write = lock.open_streams + .get(id) + .and_then(SubstreamMetadata::write_task) + .cloned(); + lock.open_streams.insert( + *id, + SubstreamMetadata::Open { + read: Some(task::current()), + write, + }, + ); + } + loop { match lock.read_state.take().unwrap_or_default() { Header { @@ -341,17 +356,6 @@ pub fn read_stream<'a, O: Into>, T: AsyncRead>( task.notify(); } - let write = lock.open_streams - .get(id) - .and_then(SubstreamMetadata::write_task) - .cloned(); - lock.open_streams.insert( - *id, - SubstreamMetadata::Open { - read: Some(task::current()), - write, - }, - ); // We cannot make progress here, another stream has to accept this packet return on_block; }