From a266b1e72483d0b30df2f69c1caaf8a35295f04f Mon Sep 17 00:00:00 2001 From: Roman Borschel Date: Wed, 10 Apr 2019 17:54:24 +0200 Subject: [PATCH] Patch reading/writing frame lengths in libp2p-noise. (#1050) * Patch reading/writing frame lengths in libp2p-noise. Extracted from https://github.com/libp2p/rust-libp2p/pull/1027 since its fate it still undetermined. * Fix formatting. --- protocols/noise/src/io.rs | 122 +++++++++++++++++++++++++++----------- 1 file changed, 86 insertions(+), 36 deletions(-) diff --git a/protocols/noise/src/io.rs b/protocols/noise/src/io.rs index e364c0b7..5fefbe30 100644 --- a/protocols/noise/src/io.rs +++ b/protocols/noise/src/io.rs @@ -94,6 +94,8 @@ impl Handshake { } /// A noise session to a remote. +/// +/// `T` is the type of the underlying I/O resource. pub struct NoiseOutput { io: T, session: snow::Session, @@ -127,6 +129,8 @@ impl NoiseOutput { enum ReadState { /// initial state Init, + /// read frame length + ReadLen { buf: [u8; 2], off: usize }, /// read encrypted frame data ReadData { len: usize, off: usize }, /// copy decrypted frame data @@ -146,7 +150,7 @@ enum WriteState { /// accumulate write data BufferData { off: usize }, /// write frame length - WriteLen { len: usize }, + WriteLen { len: usize, buf: [u8; 2], off: usize }, /// write out encrypted data WriteData { len: usize, off: usize }, /// end of file has been reached (terminal state) @@ -162,17 +166,28 @@ impl io::Read for NoiseOutput { trace!("read state: {:?}", self.read_state); match self.read_state { ReadState::Init => { - let n = match read_frame_len(&mut self.io)? { - Some(n) => n, - None => { + self.read_state = ReadState::ReadLen { buf: [0, 0], off: 0 }; + } + ReadState::ReadLen { mut buf, mut off } => { + let n = match read_frame_len(&mut self.io, &mut buf, &mut off) { + Ok(Some(n)) => n, + Ok(None) => { trace!("read: eof"); self.read_state = ReadState::Eof(Ok(())); return Ok(0) } + Err(e) => { + if e.kind() == io::ErrorKind::WouldBlock { + // Preserve read state + self.read_state = ReadState::ReadLen { buf, off }; + } + return Err(e) + } }; trace!("read: next frame len = {}", n); if n == 0 { trace!("read: empty frame"); + self.read_state = ReadState::Init; continue } self.read_state = ReadState::ReadData { len: usize::from(n), off: 0 } @@ -204,7 +219,7 @@ impl io::Read for NoiseOutput { trace!("read: copied {}/{} bytes", *off + n, len); *off += n; if len == *off { - self.read_state = ReadState::Init + self.read_state = ReadState::ReadLen { buf: [0, 0], off: 0 }; } return Ok(n) } @@ -240,7 +255,11 @@ impl io::Write for NoiseOutput { trace!("write: encrypting {} bytes", *off); if let Ok(n) = self.session.write_message(buffer.write, buffer.write_crypto) { trace!("write: cipher text len = {} bytes", n); - self.write_state = WriteState::WriteLen { len: n } + self.write_state = WriteState::WriteLen { + len: n, + buf: u16::to_be_bytes(n as u16), + off: 0 + } } else { debug!("encryption error"); self.write_state = WriteState::EncErr; @@ -249,12 +268,21 @@ impl io::Write for NoiseOutput { } return Ok(n) } - WriteState::WriteLen { len } => { - trace!("write: writing len ({})", len); - if !write_frame_len(&mut self.io, len as u16)? { - trace!("write: eof"); - self.write_state = WriteState::Eof; - return Err(io::ErrorKind::WriteZero.into()) + WriteState::WriteLen { len, mut buf, mut off } => { + trace!("write: writing len ({}, {:?}, {}/2)", len, buf, off); + match write_frame_len(&mut self.io, &mut buf, &mut off) { + Err(e) => { + if e.kind() == io::ErrorKind::WouldBlock { + self.write_state = WriteState::WriteLen{ len, buf, off }; + } + return Err(e) + } + Ok(false) => { + trace!("write: eof"); + self.write_state = WriteState::Eof; + return Err(io::ErrorKind::WriteZero.into()) + } + Ok(true) => () } self.write_state = WriteState::WriteData { len, off: 0 } } @@ -290,19 +318,33 @@ impl io::Write for NoiseOutput { trace!("flush: encrypting {} bytes", off); if let Ok(n) = self.session.write_message(&buffer.write[.. off], buffer.write_crypto) { trace!("flush: cipher text len = {} bytes", n); - self.write_state = WriteState::WriteLen { len: n } + self.write_state = WriteState::WriteLen { + len: n, + buf: u16::to_be_bytes(n as u16), + off: 0 + } } else { debug!("encryption error"); self.write_state = WriteState::EncErr; return Err(io::ErrorKind::InvalidData.into()) } } - WriteState::WriteLen { len } => { - trace!("flush: writing len ({})", len); - if !write_frame_len(&mut self.io, len as u16)? { - trace!("write: eof"); - self.write_state = WriteState::Eof; - return Err(io::ErrorKind::WriteZero.into()) + WriteState::WriteLen { len, mut buf, mut off } => { + trace!("flush: writing len ({}, {:?}, {}/2)", len, buf, off); + match write_frame_len(&mut self.io, &mut buf, &mut off) { + Ok(true) => (), + Ok(false) => { + trace!("write: eof"); + self.write_state = WriteState::Eof; + return Err(io::ErrorKind::WriteZero.into()) + } + Err(e) => { + if e.kind() == io::ErrorKind::WouldBlock { + // Preserve write state + self.write_state = WriteState::WriteLen { len, buf, off }; + } + return Err(e) + } } self.write_state = WriteState::WriteData { len, off: 0 } } @@ -339,37 +381,45 @@ impl AsyncWrite for NoiseOutput { } } -/// Read 2 bytes as frame length. +/// Read 2 bytes as frame length from the given source into the given buffer. /// -/// Returns `None` if EOF has been encountered. -fn read_frame_len(io: &mut R) -> io::Result> { - let mut buf = [0, 0]; - let mut off = 0; +/// Panics if `off >= 2`. +/// +/// When [`io::ErrorKind::WouldBlock`] is returned, the given buffer and offset +/// may have been updated (i.e. a byte may have been read) and must be preserved +/// for the next invocation. +fn read_frame_len(io: &mut R, buf: &mut [u8; 2], off: &mut usize) + -> io::Result> +{ loop { - let n = io.read(&mut buf[off ..])?; + let n = io.read(&mut buf[*off ..])?; if n == 0 { return Ok(None) } - off += n; - if off == 2 { - return Ok(Some(u16::from_be_bytes(buf))) + *off += n; + if *off == 2 { + return Ok(Some(u16::from_be_bytes(*buf))) } } } -/// Write frame length. +/// Write 2 bytes as frame length from the given buffer into the given sink. /// -/// Returns `false` if EOF has been encountered. -fn write_frame_len(io: &mut W, len: u16) -> io::Result { - let buf = len.to_be_bytes(); - let mut off = 0; +/// Panics if `off >= 2`. +/// +/// When [`io::ErrorKind::WouldBlock`] is returned, the given offset +/// may have been updated (i.e. a byte may have been written) and must +/// be preserved for the next invocation. +fn write_frame_len(io: &mut W, buf: &[u8; 2], off: &mut usize) + -> io::Result +{ loop { - let n = io.write(&buf[off ..])?; + let n = io.write(&buf[*off ..])?; if n == 0 { return Ok(false) } - off += n; - if off == 2 { + *off += n; + if *off == 2 { return Ok(true) } }