From 3de85ba4e2339ee1b22fa09139d69055a2b2fc4b Mon Sep 17 00:00:00 2001 From: "Roman S. Borschel" Date: Mon, 9 Dec 2019 15:27:01 +0100 Subject: [PATCH 1/2] Fix possible incorrect return value from LengthDelimitedReader::write(). Due to not taking into account buf.len() when computing `written`, it may be incorrectly less than buf.len(). --- misc/multistream-select/src/length_delimited.rs | 8 ++++---- misc/multistream-select/src/negotiated.rs | 14 ++++++++------ 2 files changed, 12 insertions(+), 10 deletions(-) diff --git a/misc/multistream-select/src/length_delimited.rs b/misc/multistream-select/src/length_delimited.rs index 5d22fb10..389b8e82 100644 --- a/misc/multistream-select/src/length_delimited.rs +++ b/misc/multistream-select/src/length_delimited.rs @@ -328,17 +328,17 @@ where let n = self.inner.write_buffer.len(); self.inner.write_buffer.extend_from_slice(buf); let result = self.inner.poll_write_buffer(); - let written = n - self.inner.write_buffer.len(); + let written = n + buf.len() - self.inner.write_buffer.len(); if written == 0 { + self.inner.write_buffer.split_off(n); // Never grow the buffer. if let Err(e) = result { return Err(e) } return Err(io::ErrorKind::WouldBlock.into()) } if written < buf.len() { - if self.inner.write_buffer.len() > n { - self.inner.write_buffer.split_off(n); // Never grow the buffer. - } + debug_assert!(self.inner.write_buffer.len() > n); + self.inner.write_buffer.split_off(n); // Never grow the buffer. return Ok(written) } return Ok(buf.len()) diff --git a/misc/multistream-select/src/negotiated.rs b/misc/multistream-select/src/negotiated.rs index 3519d6cc..e0abce10 100644 --- a/misc/multistream-select/src/negotiated.rs +++ b/misc/multistream-select/src/negotiated.rs @@ -242,12 +242,14 @@ where Ok(n) => { remaining.split_to(n); if !remaining.is_empty() { - let written = if n < buf.len() { - remaining.split_off(remaining_len); - n - } else { - buf.len() - }; + let written = + if n < buf.len() { + debug_assert!(remaining.len() > remaining_len); + remaining.split_off(remaining_len); + n + } else { + buf.len() + }; debug_assert!(remaining.len() <= remaining_len); Ok(written) } else { From e4f46aed06bc749d6380ead7a39961100973016b Mon Sep 17 00:00:00 2001 From: "Roman S. Borschel" Date: Mon, 9 Dec 2019 17:51:47 +0100 Subject: [PATCH 2/2] Simplify and update test. Remove the optimisation of writing data out together with any remaining buffer for simplicity. --- .../src/length_delimited.rs | 20 +----- misc/multistream-select/src/negotiated.rs | 67 ++++++------------- 2 files changed, 21 insertions(+), 66 deletions(-) diff --git a/misc/multistream-select/src/length_delimited.rs b/misc/multistream-select/src/length_delimited.rs index 389b8e82..91e3fe88 100644 --- a/misc/multistream-select/src/length_delimited.rs +++ b/misc/multistream-select/src/length_delimited.rs @@ -323,27 +323,11 @@ where R: AsyncWrite { fn write(&mut self, buf: &[u8]) -> io::Result { - // Try to drain the write buffer together with writing `buf`. - if !self.inner.write_buffer.is_empty() { - let n = self.inner.write_buffer.len(); - self.inner.write_buffer.extend_from_slice(buf); - let result = self.inner.poll_write_buffer(); - let written = n + buf.len() - self.inner.write_buffer.len(); - if written == 0 { - self.inner.write_buffer.split_off(n); // Never grow the buffer. - if let Err(e) = result { - return Err(e) - } + while !self.inner.write_buffer.is_empty() { + if self.inner.poll_write_buffer()?.is_not_ready() { return Err(io::ErrorKind::WouldBlock.into()) } - if written < buf.len() { - debug_assert!(self.inner.write_buffer.len() > n); - self.inner.write_buffer.split_off(n); // Never grow the buffer. - return Ok(written) - } - return Ok(buf.len()) } - self.inner_mut().write(buf) } diff --git a/misc/multistream-select/src/negotiated.rs b/misc/multistream-select/src/negotiated.rs index e0abce10..5e2c7ac9 100644 --- a/misc/multistream-select/src/negotiated.rs +++ b/misc/multistream-select/src/negotiated.rs @@ -227,39 +227,14 @@ where fn write(&mut self, buf: &[u8]) -> io::Result { match &mut self.state { State::Completed { io, ref mut remaining } => { - if !remaining.is_empty() { - // Try to write `buf` together with `remaining` for efficiency, - // regardless of whether the underlying I/O stream is buffered. - // Every call to `write` may imply a syscall and separate - // network packet. - let remaining_len = remaining.len(); - remaining.extend_from_slice(buf); - match io.write(&remaining) { - Err(e) => { - remaining.split_off(remaining_len); - Err(e) - } - Ok(n) => { - remaining.split_to(n); - if !remaining.is_empty() { - let written = - if n < buf.len() { - debug_assert!(remaining.len() > remaining_len); - remaining.split_off(remaining_len); - n - } else { - buf.len() - }; - debug_assert!(remaining.len() <= remaining_len); - Ok(written) - } else { - Ok(buf.len()) - } - } + while !remaining.is_empty() { + let n = io.write(&remaining)?; + if n == 0 { + return Err(io::ErrorKind::WriteZero.into()) } - } else { - io.write(buf) + remaining.split_to(n); } + io.write(buf) }, State::Expecting { io, .. } => io.write(buf), State::Invalid => panic!("Negotiated: Invalid state") @@ -384,44 +359,40 @@ mod tests { #[test] fn write_remaining() { - fn prop(rem: Vec, new: Vec, free: u8) -> TestResult { + fn prop(rem: Vec, new: Vec, free: u8, step: u8) -> TestResult { let cap = rem.len() + free as usize; - let buf = Capped { buf: Vec::with_capacity(cap), step: free as usize }; - let mut rem = BytesMut::from(rem); + let step = u8::min(free, step) as usize + 1; + let buf = Capped { buf: Vec::with_capacity(cap), step }; + let rem = BytesMut::from(rem); let mut io = Negotiated::completed(buf, rem.clone()); let mut written = 0; loop { - // Write until `new` has been fully written or the capped buffer is - // full (in which case the buffer should remain unchanged from the - // last successful write). + // Write until `new` has been fully written or the capped buffer runs + // over capacity and yields WriteZero. match io.write(&new[written..]) { Ok(n) => if let State::Completed { remaining, .. } = &io.state { - if n == rem.len() + new[written..].len() { - assert!(remaining.is_empty()) - } else { - assert!(remaining.len() <= rem.len()); - } + assert!(remaining.is_empty()); written += n; if written == new.len() { return TestResult::passed() } - rem = remaining.clone(); } else { return TestResult::failed() } - Err(_) => - if let State::Completed { remaining, .. } = &io.state { - assert!(rem.len() + new[written..].len() > cap); - assert_eq!(remaining, &rem); + Err(e) if e.kind() == io::ErrorKind::WriteZero => { + if let State::Completed { .. } = &io.state { + assert!(rem.len() + new.len() > cap); return TestResult::passed() } else { return TestResult::failed() } + } + Err(e) => panic!("Unexpected error: {:?}", e) } } } - quickcheck(prop as fn(_,_,_) -> _) + quickcheck(prop as fn(_,_,_,_) -> _) } }