Grow noise buffers dynamically. (#1436)

* Grow noise buffers dynamically.

Currently we allocate a buffer of 176 KiB for each noise state, i.e.
each connection. For connections which see only small data frames
this is wasteful. At the same time we limit the max. write buffer size
to 16 KiB to keep the total buffer size relatively small, which
results in smaller encrypted messages and also makes it less likely to
ever encounter the max. noise package size of 64 KiB in practice when
communicating with other nodes using the same implementation.

This PR repaces the static buffer allocation with a dynamic one. We
only reserve a small space for the authentication tag plus some extra
reserve and are able to buffer larger data frames before encrypting.

* Grow write buffer from offset.

As suggested by @mxinden, this prevents increasing the write buffer up
to MAX_WRITE_BUF_LEN.

Co-authored-by: Pierre Krieger <pierre.krieger1708@gmail.com>
This commit is contained in:
Toralf Wittner
2020-02-13 12:38:33 +01:00
committed by GitHub
parent bbed28b3ec
commit 70d634daff
3 changed files with 106 additions and 106 deletions

View File

@ -16,6 +16,7 @@ log = "0.4"
prost = "0.6.1"
rand = "0.7.2"
sha2 = "0.8.0"
static_assertions = "1"
x25519-dalek = "0.5"
zeroize = "1"
@ -25,7 +26,6 @@ snow = { version = "0.6.1", features = ["ring-resolver"], default-features = fal
[target.'cfg(target_os = "unknown")'.dependencies]
snow = { version = "0.6.1", features = ["default-resolver"], default-features = false }
[dev-dependencies]
env_logger = "0.7.1"
libp2p-tcp = { version = "0.15.0", path = "../../transports/tcp" }

View File

@ -26,33 +26,17 @@ use futures::ready;
use futures::prelude::*;
use log::{debug, trace};
use snow;
use std::{fmt, io, pin::Pin, ops::DerefMut, task::{Context, Poll}};
use std::{cmp::min, fmt, io, pin::Pin, ops::DerefMut, task::{Context, Poll}};
/// Max. size of a noise package.
const MAX_NOISE_PKG_LEN: usize = 65535;
const MAX_WRITE_BUF_LEN: usize = 16384;
const TOTAL_BUFFER_LEN: usize = 2 * MAX_NOISE_PKG_LEN + 3 * MAX_WRITE_BUF_LEN;
/// Extra space given to the encryption buffer to hold key material.
const EXTRA_ENCRYPT_SPACE: usize = 1024;
/// Max. output buffer size before forcing a flush.
const MAX_WRITE_BUF_LEN: usize = MAX_NOISE_PKG_LEN - EXTRA_ENCRYPT_SPACE;
/// A single `Buffer` contains multiple non-overlapping byte buffers.
struct Buffer {
inner: Box<[u8; TOTAL_BUFFER_LEN]>
}
/// A mutable borrow of all byte buffers, backed by `Buffer`.
struct BufferBorrow<'a> {
read: &'a mut [u8],
read_crypto: &'a mut [u8],
write: &'a mut [u8],
write_crypto: &'a mut [u8]
}
impl Buffer {
/// Create a mutable borrow by splitting the buffer slice.
fn borrow_mut(&mut self) -> BufferBorrow<'_> {
let (r, w) = self.inner.split_at_mut(2 * MAX_NOISE_PKG_LEN);
let (read, read_crypto) = r.split_at_mut(MAX_NOISE_PKG_LEN);
let (write, write_crypto) = w.split_at_mut(MAX_WRITE_BUF_LEN);
BufferBorrow { read, read_crypto, write, write_crypto }
}
static_assertions::const_assert! {
MAX_WRITE_BUF_LEN + EXTRA_ENCRYPT_SPACE <= MAX_NOISE_PKG_LEN
}
/// A passthrough enum for the two kinds of state machines in `snow`
@ -97,9 +81,12 @@ impl SnowState {
pub struct NoiseOutput<T> {
io: T,
session: SnowState,
buffer: Buffer,
read_state: ReadState,
write_state: WriteState
write_state: WriteState,
read_buffer: Vec<u8>,
write_buffer: Vec<u8>,
decrypt_buffer: Vec<u8>,
encrypt_buffer: Vec<u8>
}
impl<T> fmt::Debug for NoiseOutput<T> {
@ -116,9 +103,12 @@ impl<T> NoiseOutput<T> {
NoiseOutput {
io,
session,
buffer: Buffer { inner: Box::new([0; TOTAL_BUFFER_LEN]) },
read_state: ReadState::Init,
write_state: WriteState::Init
write_state: WriteState::Init,
read_buffer: Vec::new(),
write_buffer: Vec::new(),
decrypt_buffer: Vec::new(),
encrypt_buffer: Vec::new()
}
}
}
@ -159,15 +149,8 @@ enum WriteState {
}
impl<T: AsyncRead + Unpin> AsyncRead for NoiseOutput<T> {
fn poll_read(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut [u8],
) -> Poll<Result<usize, std::io::Error>> {
fn poll_read(mut self: Pin<&mut Self>, cx: &mut Context, buf: &mut [u8]) -> Poll<io::Result<usize>> {
let mut this = self.deref_mut();
let buffer = this.buffer.borrow_mut();
loop {
trace!("read state: {:?}", this.read_state);
match this.read_state {
@ -187,7 +170,6 @@ impl<T: AsyncRead + Unpin> AsyncRead for NoiseOutput<T> {
}
Poll::Pending => {
this.read_state = ReadState::ReadLen { buf, off };
return Poll::Pending;
}
};
@ -197,30 +179,28 @@ impl<T: AsyncRead + Unpin> AsyncRead for NoiseOutput<T> {
this.read_state = ReadState::Init;
continue
}
this.read_buffer.resize(usize::from(n), 0u8);
this.read_state = ReadState::ReadData { len: usize::from(n), off: 0 }
}
ReadState::ReadData { len, ref mut off } => {
let n = match ready!(
Pin::new(&mut this.io).poll_read(cx, &mut buffer.read[*off ..len])
) {
Ok(n) => n,
Err(e) => return Poll::Ready(Err(e)),
let n = {
let f = Pin::new(&mut this.io).poll_read(cx, &mut this.read_buffer[*off .. len]);
match ready!(f) {
Ok(n) => n,
Err(e) => return Poll::Ready(Err(e)),
}
};
trace!("read: read {}/{} bytes", *off + n, len);
if n == 0 {
trace!("read: eof");
this.read_state = ReadState::Eof(Err(()));
return Poll::Ready(Err(io::ErrorKind::UnexpectedEof.into()))
}
*off += n;
if len == *off {
trace!("read: decrypting {} bytes", len);
if let Ok(n) = this.session.read_message(
&buffer.read[.. len],
buffer.read_crypto
){
this.decrypt_buffer.resize(len, 0u8);
if let Ok(n) = this.session.read_message(&this.read_buffer, &mut this.decrypt_buffer) {
trace!("read: payload len = {} bytes", n);
this.read_state = ReadState::CopyData { len: n, off: 0 }
} else {
@ -231,8 +211,8 @@ impl<T: AsyncRead + Unpin> AsyncRead for NoiseOutput<T> {
}
}
ReadState::CopyData { len, ref mut off } => {
let n = std::cmp::min(len - *off, buf.len());
buf[.. n].copy_from_slice(&buffer.read_crypto[*off .. *off + n]);
let n = min(len - *off, buf.len());
buf[.. n].copy_from_slice(&this.decrypt_buffer[*off .. *off + n]);
trace!("read: copied {}/{} bytes", *off + n, len);
*off += n;
if len == *off {
@ -255,15 +235,8 @@ impl<T: AsyncRead + Unpin> AsyncRead for NoiseOutput<T> {
}
impl<T: AsyncWrite + Unpin> AsyncWrite for NoiseOutput<T> {
fn poll_write(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<Result<usize, std::io::Error>>{
fn poll_write(mut self: Pin<&mut Self>, cx: &mut Context, buf: &[u8]) -> Poll<io::Result<usize>> {
let mut this = self.deref_mut();
let buffer = this.buffer.borrow_mut();
loop {
trace!("write state: {:?}", this.write_state);
match this.write_state {
@ -271,13 +244,16 @@ impl<T: AsyncWrite + Unpin> AsyncWrite for NoiseOutput<T> {
this.write_state = WriteState::BufferData { off: 0 }
}
WriteState::BufferData { ref mut off } => {
let n = std::cmp::min(MAX_WRITE_BUF_LEN - *off, buf.len());
buffer.write[*off .. *off + n].copy_from_slice(&buf[.. n]);
let n = min(MAX_WRITE_BUF_LEN, off.saturating_add(buf.len()));
this.write_buffer.resize(n, 0u8);
let n = min(MAX_WRITE_BUF_LEN - *off, buf.len());
this.write_buffer[*off .. *off + n].copy_from_slice(&buf[.. n]);
trace!("write: buffered {} bytes", *off + n);
*off += n;
if *off == MAX_WRITE_BUF_LEN {
trace!("write: encrypting {} bytes", *off);
match this.session.write_message(buffer.write, buffer.write_crypto) {
this.encrypt_buffer.resize(MAX_WRITE_BUF_LEN + EXTRA_ENCRYPT_SPACE, 0u8);
match this.session.write_message(&this.write_buffer, &mut this.encrypt_buffer) {
Ok(n) => {
trace!("write: cipher text len = {} bytes", n);
this.write_state = WriteState::WriteLen {
@ -316,11 +292,12 @@ impl<T: AsyncWrite + Unpin> AsyncWrite for NoiseOutput<T> {
this.write_state = WriteState::WriteData { len, off: 0 }
}
WriteState::WriteData { len, ref mut off } => {
let n = match ready!(
Pin::new(&mut this.io).poll_write(cx, &buffer.write_crypto[*off .. len])
) {
Ok(n) => n,
Err(e) => return Poll::Ready(Err(e)),
let n = {
let f = Pin::new(&mut this.io).poll_write(cx, &this.encrypt_buffer[*off .. len]);
match ready!(f) {
Ok(n) => n,
Err(e) => return Poll::Ready(Err(e))
}
};
trace!("write: wrote {}/{} bytes", *off + n, len);
if n == 0 {
@ -343,20 +320,17 @@ impl<T: AsyncWrite + Unpin> AsyncWrite for NoiseOutput<T> {
}
}
fn poll_flush(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>
) -> Poll<Result<(), std::io::Error>> {
fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<io::Result<()>> {
let mut this = self.deref_mut();
let buffer = this.buffer.borrow_mut();
loop {
match this.write_state {
WriteState::Init => return Pin::new(&mut this.io).poll_flush(cx),
WriteState::Init => {
return Pin::new(&mut this.io).poll_flush(cx)
}
WriteState::BufferData { off } => {
trace!("flush: encrypting {} bytes", off);
match this.session.write_message(&buffer.write[.. off], buffer.write_crypto) {
this.encrypt_buffer.resize(off + EXTRA_ENCRYPT_SPACE, 0u8);
match this.session.write_message(&this.write_buffer[.. off], &mut this.encrypt_buffer) {
Ok(n) => {
trace!("flush: cipher text len = {} bytes", n);
this.write_state = WriteState::WriteLen {
@ -386,18 +360,18 @@ impl<T: AsyncWrite + Unpin> AsyncWrite for NoiseOutput<T> {
}
Poll::Pending => {
this.write_state = WriteState::WriteLen { len, buf, off };
return Poll::Pending
}
}
this.write_state = WriteState::WriteData { len, off: 0 }
}
WriteState::WriteData { len, ref mut off } => {
let n = match ready!(
Pin::new(&mut this.io).poll_write(cx, &buffer.write_crypto[*off .. len])
) {
Ok(n) => n,
Err(e) => return Poll::Ready(Err(e)),
let n = {
let f = Pin::new(&mut this.io).poll_write(cx, &this.encrypt_buffer[*off .. len]);
match ready!(f) {
Ok(n) => n,
Err(e) => return Poll::Ready(Err(e)),
}
};
trace!("flush: wrote {}/{} bytes", *off + n, len);
if n == 0 {
@ -420,10 +394,7 @@ impl<T: AsyncWrite + Unpin> AsyncWrite for NoiseOutput<T> {
}
}
fn poll_close(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<Result<(), std::io::Error>>{
fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<io::Result<()>>{
ready!(self.as_mut().poll_flush(cx))?;
Pin::new(&mut self.io).poll_close(cx)
}
@ -443,7 +414,7 @@ fn read_frame_len<R: AsyncRead + Unpin>(
cx: &mut Context<'_>,
buf: &mut [u8; 2],
off: &mut usize,
) -> Poll<Result<Option<u16>, std::io::Error>> {
) -> Poll<io::Result<Option<u16>>> {
loop {
match ready!(Pin::new(&mut io).poll_read(cx, &mut buf[*off ..])) {
Ok(n) => {
@ -476,7 +447,7 @@ fn write_frame_len<W: AsyncWrite + Unpin>(
cx: &mut Context<'_>,
buf: &[u8; 2],
off: &mut usize,
) -> Poll<Result<bool, std::io::Error>> {
) -> Poll<io::Result<bool>> {
loop {
match ready!(Pin::new(&mut io).poll_write(cx, &buf[*off ..])) {
Ok(n) => {

View File

@ -26,6 +26,7 @@ use libp2p_noise::{Keypair, X25519, NoiseConfig, RemoteIdentity, NoiseError, Noi
use libp2p_tcp::{TcpConfig, TcpTransStream};
use log::info;
use quickcheck::QuickCheck;
use std::{convert::TryInto, io};
#[allow(dead_code)]
fn core_upgrade_compat() {
@ -40,7 +41,8 @@ fn core_upgrade_compat() {
#[test]
fn xx() {
let _ = env_logger::try_init();
fn prop(message: Vec<u8>) -> bool {
fn prop(mut messages: Vec<Message>) -> bool {
messages.truncate(5);
let server_id = identity::Keypair::generate_ed25519();
let client_id = identity::Keypair::generate_ed25519();
@ -61,16 +63,17 @@ fn xx() {
})
.and_then(move |out, _| expect_identity(out, &server_id_public));
run(server_transport, client_transport, message);
run(server_transport, client_transport, messages);
true
}
QuickCheck::new().max_tests(30).quickcheck(prop as fn(Vec<u8>) -> bool)
QuickCheck::new().max_tests(30).quickcheck(prop as fn(Vec<Message>) -> bool)
}
#[test]
fn ix() {
let _ = env_logger::try_init();
fn prop(message: Vec<u8>) -> bool {
fn prop(mut messages: Vec<Message>) -> bool {
messages.truncate(5);
let server_id = identity::Keypair::generate_ed25519();
let client_id = identity::Keypair::generate_ed25519();
@ -91,16 +94,17 @@ fn ix() {
})
.and_then(move |out, _| expect_identity(out, &server_id_public));
run(server_transport, client_transport, message);
run(server_transport, client_transport, messages);
true
}
QuickCheck::new().max_tests(30).quickcheck(prop as fn(Vec<u8>) -> bool)
QuickCheck::new().max_tests(30).quickcheck(prop as fn(Vec<Message>) -> bool)
}
#[test]
fn ik_xx() {
let _ = env_logger::try_init();
fn prop(message: Vec<u8>) -> bool {
fn prop(mut messages: Vec<Message>) -> bool {
messages.truncate(5);
let server_id = identity::Keypair::generate_ed25519();
let server_id_public = server_id.public();
@ -134,15 +138,15 @@ fn ik_xx() {
})
.and_then(move |out, _| expect_identity(out, &server_id_public2));
run(server_transport, client_transport, message);
run(server_transport, client_transport, messages);
true
}
QuickCheck::new().max_tests(30).quickcheck(prop as fn(Vec<u8>) -> bool)
QuickCheck::new().max_tests(30).quickcheck(prop as fn(Vec<Message>) -> bool)
}
type Output = (RemoteIdentity<X25519>, NoiseOutput<Negotiated<TcpTransStream>>);
fn run<T, U>(server_transport: T, client_transport: U, message1: Vec<u8>)
fn run<T, U, I>(server_transport: T, client_transport: U, messages: I)
where
T: Transport<Output = Output>,
T::Dial: Send + 'static,
@ -152,10 +156,9 @@ where
U::Dial: Send + 'static,
U::Listener: Send + 'static,
U::ListenerUpgrade: Send + 'static,
I: IntoIterator<Item = Message> + Clone
{
futures::executor::block_on(async {
let mut message2 = message1.clone();
let mut server: T::Listener = server_transport
.listen_on("/ip4/127.0.0.1/tcp/0".parse().unwrap())
.unwrap();
@ -167,6 +170,7 @@ where
.into_new_address()
.expect("listen address");
let outbound_msgs = messages.clone();
let client_fut = async {
let mut client_session = client_transport.dial(server_address.clone())
.unwrap()
@ -174,7 +178,11 @@ where
.map(|(_, session)| session)
.expect("no error");
client_session.write_all(&mut message2).await.expect("no error");
for m in outbound_msgs {
let n = (m.0.len() as u64).to_be_bytes();
client_session.write_all(&n[..]).await.expect("len written");
client_session.write_all(&m.0).await.expect("no error")
}
client_session.flush().await.expect("no error");
};
@ -190,11 +198,20 @@ where
.map(|(_, session)| session)
.expect("no error");
let mut server_buffer = vec![];
info!("server: reading message");
server_session.read_to_end(&mut server_buffer).await.expect("no error");
assert_eq!(server_buffer, message1);
for m in messages {
let len = {
let mut n = [0; 8];
match server_session.read_exact(&mut n).await {
Ok(()) => u64::from_be_bytes(n),
Err(e) if e.kind() == io::ErrorKind::UnexpectedEof => 0,
Err(e) => panic!("error reading len: {}", e)
}
};
info!("server: reading message ({} bytes)", len);
let mut server_buffer = vec![0; len.try_into().unwrap()];
server_session.read_exact(&mut server_buffer).await.expect("no error");
assert_eq!(server_buffer, m.0)
}
};
futures::future::join(server_fut, client_fut).await;
@ -209,3 +226,15 @@ fn expect_identity(output: Output, pk: &identity::PublicKey)
_ => panic!("Unexpected remote identity")
}
}
#[derive(Debug, Clone, PartialEq, Eq)]
struct Message(Vec<u8>);
impl quickcheck::Arbitrary for Message {
fn arbitrary<G: quickcheck::Gen>(g: &mut G) -> Self {
let s = 1 + g.next_u32() % (128 * 1024);
let mut v = vec![0; s.try_into().unwrap()];
g.fill_bytes(&mut v);
Message(v)
}
}