diff --git a/misc/multistream-select/src/negotiated.rs b/misc/multistream-select/src/negotiated.rs index 3519d6cc..5e2c7ac9 100644 --- a/misc/multistream-select/src/negotiated.rs +++ b/misc/multistream-select/src/negotiated.rs @@ -227,37 +227,14 @@ where fn write(&mut self, buf: &[u8]) -> io::Result { match &mut self.state { State::Completed { io, ref mut remaining } => { - if !remaining.is_empty() { - // Try to write `buf` together with `remaining` for efficiency, - // regardless of whether the underlying I/O stream is buffered. - // Every call to `write` may imply a syscall and separate - // network packet. - let remaining_len = remaining.len(); - remaining.extend_from_slice(buf); - match io.write(&remaining) { - Err(e) => { - remaining.split_off(remaining_len); - Err(e) - } - Ok(n) => { - remaining.split_to(n); - if !remaining.is_empty() { - let written = if n < buf.len() { - remaining.split_off(remaining_len); - n - } else { - buf.len() - }; - debug_assert!(remaining.len() <= remaining_len); - Ok(written) - } else { - Ok(buf.len()) - } - } + while !remaining.is_empty() { + let n = io.write(&remaining)?; + if n == 0 { + return Err(io::ErrorKind::WriteZero.into()) } - } else { - io.write(buf) + remaining.split_to(n); } + io.write(buf) }, State::Expecting { io, .. } => io.write(buf), State::Invalid => panic!("Negotiated: Invalid state") @@ -382,44 +359,40 @@ mod tests { #[test] fn write_remaining() { - fn prop(rem: Vec, new: Vec, free: u8) -> TestResult { + fn prop(rem: Vec, new: Vec, free: u8, step: u8) -> TestResult { let cap = rem.len() + free as usize; - let buf = Capped { buf: Vec::with_capacity(cap), step: free as usize }; - let mut rem = BytesMut::from(rem); + let step = u8::min(free, step) as usize + 1; + let buf = Capped { buf: Vec::with_capacity(cap), step }; + let rem = BytesMut::from(rem); let mut io = Negotiated::completed(buf, rem.clone()); let mut written = 0; loop { - // Write until `new` has been fully written or the capped buffer is - // full (in which case the buffer should remain unchanged from the - // last successful write). + // Write until `new` has been fully written or the capped buffer runs + // over capacity and yields WriteZero. match io.write(&new[written..]) { Ok(n) => if let State::Completed { remaining, .. } = &io.state { - if n == rem.len() + new[written..].len() { - assert!(remaining.is_empty()) - } else { - assert!(remaining.len() <= rem.len()); - } + assert!(remaining.is_empty()); written += n; if written == new.len() { return TestResult::passed() } - rem = remaining.clone(); } else { return TestResult::failed() } - Err(_) => - if let State::Completed { remaining, .. } = &io.state { - assert!(rem.len() + new[written..].len() > cap); - assert_eq!(remaining, &rem); + Err(e) if e.kind() == io::ErrorKind::WriteZero => { + if let State::Completed { .. } = &io.state { + assert!(rem.len() + new.len() > cap); return TestResult::passed() } else { return TestResult::failed() } + } + Err(e) => panic!("Unexpected error: {:?}", e) } } } - quickcheck(prop as fn(_,_,_) -> _) + quickcheck(prop as fn(_,_,_,_) -> _) } }