Merge pull request #1338 from romanb/multistream-select-io-fix

Fix possible incorrect return value from <LengthDelimitedReader as io::Write>::write().
This commit is contained in:
Pierre Krieger 2019-12-10 11:29:05 +01:00 committed by GitHub
commit 7745bfd01f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 21 additions and 64 deletions

View File

@ -323,27 +323,11 @@ where
R: AsyncWrite
{
fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
// Try to drain the write buffer together with writing `buf`.
if !self.inner.write_buffer.is_empty() {
let n = self.inner.write_buffer.len();
self.inner.write_buffer.extend_from_slice(buf);
let result = self.inner.poll_write_buffer();
let written = n - self.inner.write_buffer.len();
if written == 0 {
if let Err(e) = result {
return Err(e)
}
while !self.inner.write_buffer.is_empty() {
if self.inner.poll_write_buffer()?.is_not_ready() {
return Err(io::ErrorKind::WouldBlock.into())
}
if written < buf.len() {
if self.inner.write_buffer.len() > n {
self.inner.write_buffer.split_off(n); // Never grow the buffer.
}
return Ok(written)
}
return Ok(buf.len())
}
self.inner_mut().write(buf)
}

View File

@ -227,37 +227,14 @@ where
fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
match &mut self.state {
State::Completed { io, ref mut remaining } => {
if !remaining.is_empty() {
// Try to write `buf` together with `remaining` for efficiency,
// regardless of whether the underlying I/O stream is buffered.
// Every call to `write` may imply a syscall and separate
// network packet.
let remaining_len = remaining.len();
remaining.extend_from_slice(buf);
match io.write(&remaining) {
Err(e) => {
remaining.split_off(remaining_len);
Err(e)
}
Ok(n) => {
remaining.split_to(n);
if !remaining.is_empty() {
let written = if n < buf.len() {
remaining.split_off(remaining_len);
n
} else {
buf.len()
};
debug_assert!(remaining.len() <= remaining_len);
Ok(written)
} else {
Ok(buf.len())
}
}
while !remaining.is_empty() {
let n = io.write(&remaining)?;
if n == 0 {
return Err(io::ErrorKind::WriteZero.into())
}
} else {
io.write(buf)
remaining.split_to(n);
}
io.write(buf)
},
State::Expecting { io, .. } => io.write(buf),
State::Invalid => panic!("Negotiated: Invalid state")
@ -382,44 +359,40 @@ mod tests {
#[test]
fn write_remaining() {
fn prop(rem: Vec<u8>, new: Vec<u8>, free: u8) -> TestResult {
fn prop(rem: Vec<u8>, new: Vec<u8>, free: u8, step: u8) -> TestResult {
let cap = rem.len() + free as usize;
let buf = Capped { buf: Vec::with_capacity(cap), step: free as usize };
let mut rem = BytesMut::from(rem);
let step = u8::min(free, step) as usize + 1;
let buf = Capped { buf: Vec::with_capacity(cap), step };
let rem = BytesMut::from(rem);
let mut io = Negotiated::completed(buf, rem.clone());
let mut written = 0;
loop {
// Write until `new` has been fully written or the capped buffer is
// full (in which case the buffer should remain unchanged from the
// last successful write).
// Write until `new` has been fully written or the capped buffer runs
// over capacity and yields WriteZero.
match io.write(&new[written..]) {
Ok(n) =>
if let State::Completed { remaining, .. } = &io.state {
if n == rem.len() + new[written..].len() {
assert!(remaining.is_empty())
} else {
assert!(remaining.len() <= rem.len());
}
assert!(remaining.is_empty());
written += n;
if written == new.len() {
return TestResult::passed()
}
rem = remaining.clone();
} else {
return TestResult::failed()
}
Err(_) =>
if let State::Completed { remaining, .. } = &io.state {
assert!(rem.len() + new[written..].len() > cap);
assert_eq!(remaining, &rem);
Err(e) if e.kind() == io::ErrorKind::WriteZero => {
if let State::Completed { .. } = &io.state {
assert!(rem.len() + new.len() > cap);
return TestResult::passed()
} else {
return TestResult::failed()
}
}
Err(e) => panic!("Unexpected error: {:?}", e)
}
}
}
quickcheck(prop as fn(_,_,_) -> _)
quickcheck(prop as fn(_,_,_,_) -> _)
}
}