diff --git a/protocols/noise/src/io.rs b/protocols/noise/src/io.rs index 472bcefd..ad7f541f 100644 --- a/protocols/noise/src/io.rs +++ b/protocols/noise/src/io.rs @@ -22,7 +22,7 @@ pub mod handshake; -use futures::Poll; +use futures::{Async, Poll}; use log::{debug, trace}; use snow; use snow::error::{StateProblem, Error as SnowError}; @@ -313,7 +313,7 @@ impl io::Write for NoiseOutput { let buffer = self.buffer.borrow_mut(); loop { match self.write_state { - WriteState::Init => return Ok(()), + WriteState::Init => return self.io.flush(), WriteState::BufferData { off } => { trace!("flush: encrypting {} bytes", off); if let Ok(n) = self.session.write_message(&buffer.write[.. off], buffer.write_crypto) { @@ -360,7 +360,6 @@ impl io::Write for NoiseOutput { if len == *off { trace!("flush: finished writing {} bytes", len); self.write_state = WriteState::Init; - return Ok(()) } } WriteState::Eof => { @@ -381,7 +380,11 @@ impl AsyncRead for NoiseOutput { impl AsyncWrite for NoiseOutput { fn shutdown(&mut self) -> Poll<(), io::Error> { - self.io.shutdown() + match io::Write::flush(self) { + Ok(_) => self.io.shutdown(), + Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => Ok(Async::NotReady), + Err(e) => Err(e), + } } }