Patch reading/writing frame lengths in libp2p-noise. (#1050)

* Patch reading/writing frame lengths in libp2p-noise.

Extracted from https://github.com/libp2p/rust-libp2p/pull/1027 since its
fate it still undetermined.

* Fix formatting.
This commit is contained in:
Roman Borschel 2019-04-10 17:54:24 +02:00 committed by GitHub
parent 6917b8f543
commit a266b1e724
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -94,6 +94,8 @@ impl<T: AsyncRead + AsyncWrite> Handshake<T> {
} }
/// A noise session to a remote. /// A noise session to a remote.
///
/// `T` is the type of the underlying I/O resource.
pub struct NoiseOutput<T> { pub struct NoiseOutput<T> {
io: T, io: T,
session: snow::Session, session: snow::Session,
@ -127,6 +129,8 @@ impl<T> NoiseOutput<T> {
enum ReadState { enum ReadState {
/// initial state /// initial state
Init, Init,
/// read frame length
ReadLen { buf: [u8; 2], off: usize },
/// read encrypted frame data /// read encrypted frame data
ReadData { len: usize, off: usize }, ReadData { len: usize, off: usize },
/// copy decrypted frame data /// copy decrypted frame data
@ -146,7 +150,7 @@ enum WriteState {
/// accumulate write data /// accumulate write data
BufferData { off: usize }, BufferData { off: usize },
/// write frame length /// write frame length
WriteLen { len: usize }, WriteLen { len: usize, buf: [u8; 2], off: usize },
/// write out encrypted data /// write out encrypted data
WriteData { len: usize, off: usize }, WriteData { len: usize, off: usize },
/// end of file has been reached (terminal state) /// end of file has been reached (terminal state)
@ -162,17 +166,28 @@ impl<T: io::Read> io::Read for NoiseOutput<T> {
trace!("read state: {:?}", self.read_state); trace!("read state: {:?}", self.read_state);
match self.read_state { match self.read_state {
ReadState::Init => { ReadState::Init => {
let n = match read_frame_len(&mut self.io)? { self.read_state = ReadState::ReadLen { buf: [0, 0], off: 0 };
Some(n) => n, }
None => { ReadState::ReadLen { mut buf, mut off } => {
let n = match read_frame_len(&mut self.io, &mut buf, &mut off) {
Ok(Some(n)) => n,
Ok(None) => {
trace!("read: eof"); trace!("read: eof");
self.read_state = ReadState::Eof(Ok(())); self.read_state = ReadState::Eof(Ok(()));
return Ok(0) return Ok(0)
} }
Err(e) => {
if e.kind() == io::ErrorKind::WouldBlock {
// Preserve read state
self.read_state = ReadState::ReadLen { buf, off };
}
return Err(e)
}
}; };
trace!("read: next frame len = {}", n); trace!("read: next frame len = {}", n);
if n == 0 { if n == 0 {
trace!("read: empty frame"); trace!("read: empty frame");
self.read_state = ReadState::Init;
continue continue
} }
self.read_state = ReadState::ReadData { len: usize::from(n), off: 0 } self.read_state = ReadState::ReadData { len: usize::from(n), off: 0 }
@ -204,7 +219,7 @@ impl<T: io::Read> io::Read for NoiseOutput<T> {
trace!("read: copied {}/{} bytes", *off + n, len); trace!("read: copied {}/{} bytes", *off + n, len);
*off += n; *off += n;
if len == *off { if len == *off {
self.read_state = ReadState::Init self.read_state = ReadState::ReadLen { buf: [0, 0], off: 0 };
} }
return Ok(n) return Ok(n)
} }
@ -240,7 +255,11 @@ impl<T: io::Write> io::Write for NoiseOutput<T> {
trace!("write: encrypting {} bytes", *off); trace!("write: encrypting {} bytes", *off);
if let Ok(n) = self.session.write_message(buffer.write, buffer.write_crypto) { if let Ok(n) = self.session.write_message(buffer.write, buffer.write_crypto) {
trace!("write: cipher text len = {} bytes", n); trace!("write: cipher text len = {} bytes", n);
self.write_state = WriteState::WriteLen { len: n } self.write_state = WriteState::WriteLen {
len: n,
buf: u16::to_be_bytes(n as u16),
off: 0
}
} else { } else {
debug!("encryption error"); debug!("encryption error");
self.write_state = WriteState::EncErr; self.write_state = WriteState::EncErr;
@ -249,12 +268,21 @@ impl<T: io::Write> io::Write for NoiseOutput<T> {
} }
return Ok(n) return Ok(n)
} }
WriteState::WriteLen { len } => { WriteState::WriteLen { len, mut buf, mut off } => {
trace!("write: writing len ({})", len); trace!("write: writing len ({}, {:?}, {}/2)", len, buf, off);
if !write_frame_len(&mut self.io, len as u16)? { match write_frame_len(&mut self.io, &mut buf, &mut off) {
trace!("write: eof"); Err(e) => {
self.write_state = WriteState::Eof; if e.kind() == io::ErrorKind::WouldBlock {
return Err(io::ErrorKind::WriteZero.into()) self.write_state = WriteState::WriteLen{ len, buf, off };
}
return Err(e)
}
Ok(false) => {
trace!("write: eof");
self.write_state = WriteState::Eof;
return Err(io::ErrorKind::WriteZero.into())
}
Ok(true) => ()
} }
self.write_state = WriteState::WriteData { len, off: 0 } self.write_state = WriteState::WriteData { len, off: 0 }
} }
@ -290,19 +318,33 @@ impl<T: io::Write> io::Write for NoiseOutput<T> {
trace!("flush: encrypting {} bytes", off); trace!("flush: encrypting {} bytes", off);
if let Ok(n) = self.session.write_message(&buffer.write[.. off], buffer.write_crypto) { if let Ok(n) = self.session.write_message(&buffer.write[.. off], buffer.write_crypto) {
trace!("flush: cipher text len = {} bytes", n); trace!("flush: cipher text len = {} bytes", n);
self.write_state = WriteState::WriteLen { len: n } self.write_state = WriteState::WriteLen {
len: n,
buf: u16::to_be_bytes(n as u16),
off: 0
}
} else { } else {
debug!("encryption error"); debug!("encryption error");
self.write_state = WriteState::EncErr; self.write_state = WriteState::EncErr;
return Err(io::ErrorKind::InvalidData.into()) return Err(io::ErrorKind::InvalidData.into())
} }
} }
WriteState::WriteLen { len } => { WriteState::WriteLen { len, mut buf, mut off } => {
trace!("flush: writing len ({})", len); trace!("flush: writing len ({}, {:?}, {}/2)", len, buf, off);
if !write_frame_len(&mut self.io, len as u16)? { match write_frame_len(&mut self.io, &mut buf, &mut off) {
trace!("write: eof"); Ok(true) => (),
self.write_state = WriteState::Eof; Ok(false) => {
return Err(io::ErrorKind::WriteZero.into()) trace!("write: eof");
self.write_state = WriteState::Eof;
return Err(io::ErrorKind::WriteZero.into())
}
Err(e) => {
if e.kind() == io::ErrorKind::WouldBlock {
// Preserve write state
self.write_state = WriteState::WriteLen { len, buf, off };
}
return Err(e)
}
} }
self.write_state = WriteState::WriteData { len, off: 0 } self.write_state = WriteState::WriteData { len, off: 0 }
} }
@ -339,37 +381,45 @@ impl<T: AsyncWrite> AsyncWrite for NoiseOutput<T> {
} }
} }
/// Read 2 bytes as frame length. /// Read 2 bytes as frame length from the given source into the given buffer.
/// ///
/// Returns `None` if EOF has been encountered. /// Panics if `off >= 2`.
fn read_frame_len<R: io::Read>(io: &mut R) -> io::Result<Option<u16>> { ///
let mut buf = [0, 0]; /// When [`io::ErrorKind::WouldBlock`] is returned, the given buffer and offset
let mut off = 0; /// may have been updated (i.e. a byte may have been read) and must be preserved
/// for the next invocation.
fn read_frame_len<R: io::Read>(io: &mut R, buf: &mut [u8; 2], off: &mut usize)
-> io::Result<Option<u16>>
{
loop { loop {
let n = io.read(&mut buf[off ..])?; let n = io.read(&mut buf[*off ..])?;
if n == 0 { if n == 0 {
return Ok(None) return Ok(None)
} }
off += n; *off += n;
if off == 2 { if *off == 2 {
return Ok(Some(u16::from_be_bytes(buf))) return Ok(Some(u16::from_be_bytes(*buf)))
} }
} }
} }
/// Write frame length. /// Write 2 bytes as frame length from the given buffer into the given sink.
/// ///
/// Returns `false` if EOF has been encountered. /// Panics if `off >= 2`.
fn write_frame_len<W: io::Write>(io: &mut W, len: u16) -> io::Result<bool> { ///
let buf = len.to_be_bytes(); /// When [`io::ErrorKind::WouldBlock`] is returned, the given offset
let mut off = 0; /// may have been updated (i.e. a byte may have been written) and must
/// be preserved for the next invocation.
fn write_frame_len<W: io::Write>(io: &mut W, buf: &[u8; 2], off: &mut usize)
-> io::Result<bool>
{
loop { loop {
let n = io.write(&buf[off ..])?; let n = io.write(&buf[*off ..])?;
if n == 0 { if n == 0 {
return Ok(false) return Ok(false)
} }
off += n; *off += n;
if off == 2 { if *off == 2 {
return Ok(true) return Ok(true)
} }
} }