mirror of
https://github.com/fluencelabs/rust-libp2p
synced 2025-05-29 02:31:20 +00:00
Patch reading/writing frame lengths in libp2p-noise. (#1050)
* Patch reading/writing frame lengths in libp2p-noise. Extracted from https://github.com/libp2p/rust-libp2p/pull/1027 since its fate it still undetermined. * Fix formatting.
This commit is contained in:
parent
6917b8f543
commit
a266b1e724
@ -94,6 +94,8 @@ impl<T: AsyncRead + AsyncWrite> Handshake<T> {
|
|||||||
}
|
}
|
||||||
|
|
||||||
/// A noise session to a remote.
|
/// A noise session to a remote.
|
||||||
|
///
|
||||||
|
/// `T` is the type of the underlying I/O resource.
|
||||||
pub struct NoiseOutput<T> {
|
pub struct NoiseOutput<T> {
|
||||||
io: T,
|
io: T,
|
||||||
session: snow::Session,
|
session: snow::Session,
|
||||||
@ -127,6 +129,8 @@ impl<T> NoiseOutput<T> {
|
|||||||
enum ReadState {
|
enum ReadState {
|
||||||
/// initial state
|
/// initial state
|
||||||
Init,
|
Init,
|
||||||
|
/// read frame length
|
||||||
|
ReadLen { buf: [u8; 2], off: usize },
|
||||||
/// read encrypted frame data
|
/// read encrypted frame data
|
||||||
ReadData { len: usize, off: usize },
|
ReadData { len: usize, off: usize },
|
||||||
/// copy decrypted frame data
|
/// copy decrypted frame data
|
||||||
@ -146,7 +150,7 @@ enum WriteState {
|
|||||||
/// accumulate write data
|
/// accumulate write data
|
||||||
BufferData { off: usize },
|
BufferData { off: usize },
|
||||||
/// write frame length
|
/// write frame length
|
||||||
WriteLen { len: usize },
|
WriteLen { len: usize, buf: [u8; 2], off: usize },
|
||||||
/// write out encrypted data
|
/// write out encrypted data
|
||||||
WriteData { len: usize, off: usize },
|
WriteData { len: usize, off: usize },
|
||||||
/// end of file has been reached (terminal state)
|
/// end of file has been reached (terminal state)
|
||||||
@ -162,17 +166,28 @@ impl<T: io::Read> io::Read for NoiseOutput<T> {
|
|||||||
trace!("read state: {:?}", self.read_state);
|
trace!("read state: {:?}", self.read_state);
|
||||||
match self.read_state {
|
match self.read_state {
|
||||||
ReadState::Init => {
|
ReadState::Init => {
|
||||||
let n = match read_frame_len(&mut self.io)? {
|
self.read_state = ReadState::ReadLen { buf: [0, 0], off: 0 };
|
||||||
Some(n) => n,
|
}
|
||||||
None => {
|
ReadState::ReadLen { mut buf, mut off } => {
|
||||||
|
let n = match read_frame_len(&mut self.io, &mut buf, &mut off) {
|
||||||
|
Ok(Some(n)) => n,
|
||||||
|
Ok(None) => {
|
||||||
trace!("read: eof");
|
trace!("read: eof");
|
||||||
self.read_state = ReadState::Eof(Ok(()));
|
self.read_state = ReadState::Eof(Ok(()));
|
||||||
return Ok(0)
|
return Ok(0)
|
||||||
}
|
}
|
||||||
|
Err(e) => {
|
||||||
|
if e.kind() == io::ErrorKind::WouldBlock {
|
||||||
|
// Preserve read state
|
||||||
|
self.read_state = ReadState::ReadLen { buf, off };
|
||||||
|
}
|
||||||
|
return Err(e)
|
||||||
|
}
|
||||||
};
|
};
|
||||||
trace!("read: next frame len = {}", n);
|
trace!("read: next frame len = {}", n);
|
||||||
if n == 0 {
|
if n == 0 {
|
||||||
trace!("read: empty frame");
|
trace!("read: empty frame");
|
||||||
|
self.read_state = ReadState::Init;
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
self.read_state = ReadState::ReadData { len: usize::from(n), off: 0 }
|
self.read_state = ReadState::ReadData { len: usize::from(n), off: 0 }
|
||||||
@ -204,7 +219,7 @@ impl<T: io::Read> io::Read for NoiseOutput<T> {
|
|||||||
trace!("read: copied {}/{} bytes", *off + n, len);
|
trace!("read: copied {}/{} bytes", *off + n, len);
|
||||||
*off += n;
|
*off += n;
|
||||||
if len == *off {
|
if len == *off {
|
||||||
self.read_state = ReadState::Init
|
self.read_state = ReadState::ReadLen { buf: [0, 0], off: 0 };
|
||||||
}
|
}
|
||||||
return Ok(n)
|
return Ok(n)
|
||||||
}
|
}
|
||||||
@ -240,7 +255,11 @@ impl<T: io::Write> io::Write for NoiseOutput<T> {
|
|||||||
trace!("write: encrypting {} bytes", *off);
|
trace!("write: encrypting {} bytes", *off);
|
||||||
if let Ok(n) = self.session.write_message(buffer.write, buffer.write_crypto) {
|
if let Ok(n) = self.session.write_message(buffer.write, buffer.write_crypto) {
|
||||||
trace!("write: cipher text len = {} bytes", n);
|
trace!("write: cipher text len = {} bytes", n);
|
||||||
self.write_state = WriteState::WriteLen { len: n }
|
self.write_state = WriteState::WriteLen {
|
||||||
|
len: n,
|
||||||
|
buf: u16::to_be_bytes(n as u16),
|
||||||
|
off: 0
|
||||||
|
}
|
||||||
} else {
|
} else {
|
||||||
debug!("encryption error");
|
debug!("encryption error");
|
||||||
self.write_state = WriteState::EncErr;
|
self.write_state = WriteState::EncErr;
|
||||||
@ -249,12 +268,21 @@ impl<T: io::Write> io::Write for NoiseOutput<T> {
|
|||||||
}
|
}
|
||||||
return Ok(n)
|
return Ok(n)
|
||||||
}
|
}
|
||||||
WriteState::WriteLen { len } => {
|
WriteState::WriteLen { len, mut buf, mut off } => {
|
||||||
trace!("write: writing len ({})", len);
|
trace!("write: writing len ({}, {:?}, {}/2)", len, buf, off);
|
||||||
if !write_frame_len(&mut self.io, len as u16)? {
|
match write_frame_len(&mut self.io, &mut buf, &mut off) {
|
||||||
trace!("write: eof");
|
Err(e) => {
|
||||||
self.write_state = WriteState::Eof;
|
if e.kind() == io::ErrorKind::WouldBlock {
|
||||||
return Err(io::ErrorKind::WriteZero.into())
|
self.write_state = WriteState::WriteLen{ len, buf, off };
|
||||||
|
}
|
||||||
|
return Err(e)
|
||||||
|
}
|
||||||
|
Ok(false) => {
|
||||||
|
trace!("write: eof");
|
||||||
|
self.write_state = WriteState::Eof;
|
||||||
|
return Err(io::ErrorKind::WriteZero.into())
|
||||||
|
}
|
||||||
|
Ok(true) => ()
|
||||||
}
|
}
|
||||||
self.write_state = WriteState::WriteData { len, off: 0 }
|
self.write_state = WriteState::WriteData { len, off: 0 }
|
||||||
}
|
}
|
||||||
@ -290,19 +318,33 @@ impl<T: io::Write> io::Write for NoiseOutput<T> {
|
|||||||
trace!("flush: encrypting {} bytes", off);
|
trace!("flush: encrypting {} bytes", off);
|
||||||
if let Ok(n) = self.session.write_message(&buffer.write[.. off], buffer.write_crypto) {
|
if let Ok(n) = self.session.write_message(&buffer.write[.. off], buffer.write_crypto) {
|
||||||
trace!("flush: cipher text len = {} bytes", n);
|
trace!("flush: cipher text len = {} bytes", n);
|
||||||
self.write_state = WriteState::WriteLen { len: n }
|
self.write_state = WriteState::WriteLen {
|
||||||
|
len: n,
|
||||||
|
buf: u16::to_be_bytes(n as u16),
|
||||||
|
off: 0
|
||||||
|
}
|
||||||
} else {
|
} else {
|
||||||
debug!("encryption error");
|
debug!("encryption error");
|
||||||
self.write_state = WriteState::EncErr;
|
self.write_state = WriteState::EncErr;
|
||||||
return Err(io::ErrorKind::InvalidData.into())
|
return Err(io::ErrorKind::InvalidData.into())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
WriteState::WriteLen { len } => {
|
WriteState::WriteLen { len, mut buf, mut off } => {
|
||||||
trace!("flush: writing len ({})", len);
|
trace!("flush: writing len ({}, {:?}, {}/2)", len, buf, off);
|
||||||
if !write_frame_len(&mut self.io, len as u16)? {
|
match write_frame_len(&mut self.io, &mut buf, &mut off) {
|
||||||
trace!("write: eof");
|
Ok(true) => (),
|
||||||
self.write_state = WriteState::Eof;
|
Ok(false) => {
|
||||||
return Err(io::ErrorKind::WriteZero.into())
|
trace!("write: eof");
|
||||||
|
self.write_state = WriteState::Eof;
|
||||||
|
return Err(io::ErrorKind::WriteZero.into())
|
||||||
|
}
|
||||||
|
Err(e) => {
|
||||||
|
if e.kind() == io::ErrorKind::WouldBlock {
|
||||||
|
// Preserve write state
|
||||||
|
self.write_state = WriteState::WriteLen { len, buf, off };
|
||||||
|
}
|
||||||
|
return Err(e)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
self.write_state = WriteState::WriteData { len, off: 0 }
|
self.write_state = WriteState::WriteData { len, off: 0 }
|
||||||
}
|
}
|
||||||
@ -339,37 +381,45 @@ impl<T: AsyncWrite> AsyncWrite for NoiseOutput<T> {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Read 2 bytes as frame length.
|
/// Read 2 bytes as frame length from the given source into the given buffer.
|
||||||
///
|
///
|
||||||
/// Returns `None` if EOF has been encountered.
|
/// Panics if `off >= 2`.
|
||||||
fn read_frame_len<R: io::Read>(io: &mut R) -> io::Result<Option<u16>> {
|
///
|
||||||
let mut buf = [0, 0];
|
/// When [`io::ErrorKind::WouldBlock`] is returned, the given buffer and offset
|
||||||
let mut off = 0;
|
/// may have been updated (i.e. a byte may have been read) and must be preserved
|
||||||
|
/// for the next invocation.
|
||||||
|
fn read_frame_len<R: io::Read>(io: &mut R, buf: &mut [u8; 2], off: &mut usize)
|
||||||
|
-> io::Result<Option<u16>>
|
||||||
|
{
|
||||||
loop {
|
loop {
|
||||||
let n = io.read(&mut buf[off ..])?;
|
let n = io.read(&mut buf[*off ..])?;
|
||||||
if n == 0 {
|
if n == 0 {
|
||||||
return Ok(None)
|
return Ok(None)
|
||||||
}
|
}
|
||||||
off += n;
|
*off += n;
|
||||||
if off == 2 {
|
if *off == 2 {
|
||||||
return Ok(Some(u16::from_be_bytes(buf)))
|
return Ok(Some(u16::from_be_bytes(*buf)))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Write frame length.
|
/// Write 2 bytes as frame length from the given buffer into the given sink.
|
||||||
///
|
///
|
||||||
/// Returns `false` if EOF has been encountered.
|
/// Panics if `off >= 2`.
|
||||||
fn write_frame_len<W: io::Write>(io: &mut W, len: u16) -> io::Result<bool> {
|
///
|
||||||
let buf = len.to_be_bytes();
|
/// When [`io::ErrorKind::WouldBlock`] is returned, the given offset
|
||||||
let mut off = 0;
|
/// may have been updated (i.e. a byte may have been written) and must
|
||||||
|
/// be preserved for the next invocation.
|
||||||
|
fn write_frame_len<W: io::Write>(io: &mut W, buf: &[u8; 2], off: &mut usize)
|
||||||
|
-> io::Result<bool>
|
||||||
|
{
|
||||||
loop {
|
loop {
|
||||||
let n = io.write(&buf[off ..])?;
|
let n = io.write(&buf[*off ..])?;
|
||||||
if n == 0 {
|
if n == 0 {
|
||||||
return Ok(false)
|
return Ok(false)
|
||||||
}
|
}
|
||||||
off += n;
|
*off += n;
|
||||||
if off == 2 {
|
if *off == 2 {
|
||||||
return Ok(true)
|
return Ok(true)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
Loading…
x
Reference in New Issue
Block a user