Fixes to the mplex implementation (#360)

* Fixes to the mplex implementation
* Fix mem leak and wrong logging message
* Correctly handle Close and Reset
* Check the even-ness of the substream id
Pierre Krieger
2018-08-13 11:29:07 +02:00
committed by Benjamin Kampmann
parent 73996885cb
commit b673209839
3 changed files with 110 additions and 45 deletions
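
The even-ness check mentioned in the commit message relies on how substream ids are allocated in the diff below: the dialer starts `next_outbound_stream_id` at 0, the listener at 1, and each side advances by two, so locally allocated ids keep one parity and ids opened by the remote must keep the other. A minimal sketch of that validation, not part of the patch (the helper name `inbound_open_id_is_invalid` is made up for illustration):

fn inbound_open_id_is_invalid(substream_id: u32, next_outbound_stream_id: u32) -> bool {
    // Local ids always keep the parity of `next_outbound_stream_id`, since it only ever
    // advances by two; an inbound `Open` sharing that parity is a protocol violation.
    (substream_id % 2) == (next_outbound_stream_id % 2)
}

// As a dialer (local ids 0, 2, 4, ...) an incoming `Open { substream_id: 3 }` is accepted,
// while `Open { substream_id: 4 }` is rejected with "invalid substream id opened".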


@@ -23,6 +23,8 @@ extern crate fnv;
#[macro_use]
extern crate futures;
extern crate libp2p_core as core;
#[macro_use]
extern crate log;
extern crate parking_lot;
extern crate tokio_codec;
extern crate tokio_io;
@@ -32,12 +34,11 @@ mod codec;
use std::{cmp, iter};
use std::io::{Read, Write, Error as IoError, ErrorKind as IoErrorKind};
use std::mem;
use std::sync::Arc;
use std::sync::{Arc, atomic::AtomicUsize, atomic::Ordering};
use bytes::Bytes;
use core::{ConnectionUpgrade, Endpoint, StreamMuxer};
use parking_lot::Mutex;
use fnv::FnvHashSet;
use fnv::{FnvHashMap, FnvHashSet};
use futures::prelude::*;
use futures::{future, stream::Fuse, task};
use tokio_codec::Framed;
@@ -46,7 +47,7 @@ use tokio_io::{AsyncRead, AsyncWrite};
// Maximum number of simultaneously-open substreams.
const MAX_SUBSTREAMS: usize = 1024;
// Maximum number of elements in the internal buffer.
const MAX_BUFFER_LEN: usize = 256;
const MAX_BUFFER_LEN: usize = 1024;
/// Configuration for the multiplexer.
#[derive(Debug, Clone, Default)]
@@ -74,11 +75,12 @@ where
fn upgrade(self, i: C, _: (), endpoint: Endpoint, remote_addr: Maf) -> Self::Future {
let out = Multiplex {
inner: Arc::new(Mutex::new(MultiplexInner {
error: Ok(()),
inner: Framed::new(i, codec::Codec::new()).fuse(),
buffer: Vec::with_capacity(32),
opened_substreams: Default::default(),
next_outbound_stream_id: if endpoint == Endpoint::Dialer { 0 } else { 1 },
to_notify: Vec::new(),
to_notify: Default::default(),
}))
};
@@ -107,27 +109,34 @@ impl<C> Clone for Multiplex<C> {
// Struct shared throughout the implementation.
struct MultiplexInner<C> {
// Error that happened earlier. Should poison any attempt to use this `MultiplexInner`.
error: Result<(), IoError>,
// Underlying stream.
inner: Fuse<Framed<C, codec::Codec>>,
// Buffer of elements pulled from the stream but not processed yet.
buffer: Vec<codec::Elem>,
// List of Ids of opened substreams. Used to filter out messages that don't belong to any
// substream.
// substream. Note that this is handled exclusively by `next_match`.
opened_substreams: FnvHashSet<u32>,
// Id of the next outgoing substream. Should always increase by two.
next_outbound_stream_id: u32,
// List of tasks to notify when a new element is inserted in `buffer`.
to_notify: Vec<task::Task>,
// List of tasks to notify when a new element is inserted in `buffer` or an error happens.
to_notify: FnvHashMap<usize, task::Task>,
}
// Processes elements in `inner` until one matching `filter` is found.
//
// If `NotReady` is returned, the current task is scheduled for later, just like with any `Poll`.
// `Ready(Some())` is almost always returned. `Ready(None)` is returned if the stream is EOF.
/// Processes elements in `inner` until one matching `filter` is found.
///
/// If `NotReady` is returned, the current task is scheduled for later, just like with any `Poll`.
/// `Ready(Some())` is almost always returned. `Ready(None)` is returned if the stream is EOF.
fn next_match<C, F, O>(inner: &mut MultiplexInner<C>, mut filter: F) -> Poll<Option<O>, IoError>
where C: AsyncRead + AsyncWrite,
F: FnMut(&codec::Elem) -> Option<O>,
{
// If an error happened earlier, immediately return it.
if let Err(ref err) = inner.error {
return Err(IoError::new(err.kind(), err.to_string()));
}
if let Some((offset, out)) = inner.buffer.iter().enumerate().filter_map(|(n, v)| filter(v).map(|v| (n, v))).next() {
inner.buffer.remove(offset);
return Ok(Async::Ready(Some(out)));
@@ -137,27 +146,66 @@ where C: AsyncRead + AsyncWrite,
let elem = match inner.inner.poll() {
Ok(Async::Ready(item)) => item,
Ok(Async::NotReady) => {
inner.to_notify.push(task::current());
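// Lazily assign the current task a unique id (a task-local initialised from a global
// counter) so that repeated polls by the same task overwrite their previous entry in
// `to_notify` instead of accumulating duplicate wakeup handles.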
static NEXT_TASK_ID: AtomicUsize = AtomicUsize::new(0);
task_local!{
static TASK_ID: usize = NEXT_TASK_ID.fetch_add(1, Ordering::Relaxed)
}
inner.to_notify.insert(TASK_ID.with(|&t| t), task::current());
return Ok(Async::NotReady);
},
Err(err) => {
return Err(err);
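// Record the error so that any further use of the multiplexer fails fast, then wake
// every parked task so it can observe the failure.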
let err2 = IoError::new(err.kind(), err.to_string());
inner.error = Err(err);
for task in inner.to_notify.drain() {
task.1.notify();
}
return Err(err2);
},
};
if let Some(elem) = elem {
trace!("Received message: {:?}", elem);
// Handle substreams opening/closing.
match elem {
codec::Elem::Open { substream_id } => {
if (substream_id % 2) == (inner.next_outbound_stream_id % 2) {
inner.error = Err(IoError::new(IoErrorKind::Other, "invalid substream id opened"));
for task in inner.to_notify.drain() {
task.1.notify();
}
return Err(IoError::new(IoErrorKind::Other, "invalid substream id opened"));
}
if !inner.opened_substreams.insert(substream_id) {
debug!("Received open message for substream {} which was already open", substream_id)
}
},
codec::Elem::Close { substream_id, .. } | codec::Elem::Reset { substream_id, .. } => {
inner.opened_substreams.remove(&substream_id);
},
_ => ()
}
if let Some(out) = filter(&elem) {
return Ok(Async::Ready(Some(out)));
} else {
if inner.buffer.len() >= MAX_BUFFER_LEN {
return Err(IoError::new(IoErrorKind::InvalidData, "reached maximum buffer length"));
debug!("Reached mplex maximum buffer length");
inner.error = Err(IoError::new(IoErrorKind::Other, "reached maximum buffer length"));
for task in inner.to_notify.drain() {
task.1.notify();
}
return Err(IoError::new(IoErrorKind::Other, "reached maximum buffer length"));
}
if inner.opened_substreams.contains(&elem.substream_id()) || elem.is_open_msg() {
inner.buffer.push(elem);
for task in inner.to_notify.drain(..) {
task.notify();
for task in inner.to_notify.drain() {
task.1.notify();
}
} else if !elem.is_close_or_reset_msg() {
debug!("Ignored message {:?} because the substream wasn't open", elem);
}
}
} else {
@@ -166,13 +214,6 @@ where C: AsyncRead + AsyncWrite,
}
}
// Closes a substream in `inner`.
fn clean_out_substream<C>(inner: &mut MultiplexInner<C>, num: u32) {
let was_in = inner.opened_substreams.remove(&num);
debug_assert!(was_in, "Dropped substream which wasn't open ; programmer error");
inner.buffer.retain(|elem| elem.substream_id() != num);
}
// Small convenience function that tries to write `elem` to the stream.
fn poll_send<C>(inner: &mut MultiplexInner<C>, elem: codec::Elem) -> Poll<(), IoError>
where C: AsyncRead + AsyncWrite
@@ -212,12 +253,17 @@ where C: AsyncRead + AsyncWrite + 'static // TODO: 'static :-/
};
// We use an RAII guard so that we clean up the substream in case of an error.
struct OpenedSubstreamGuard<C>(Arc<Mutex<MultiplexInner<C>>>, u32);
struct OpenedSubstreamGuard<C>(Option<Arc<Mutex<MultiplexInner<C>>>>, u32);
impl<C> Drop for OpenedSubstreamGuard<C> {
fn drop(&mut self) { clean_out_substream(&mut self.0.lock(), self.1); }
fn drop(&mut self) {
if let Some(inner) = self.0.take() {
debug!("Failed to open outbound substream {}", self.1);
inner.lock().buffer.retain(|elem| elem.substream_id() != self.1);
}
}
}
inner.opened_substreams.insert(substream_id);
let guard = OpenedSubstreamGuard(self.inner.clone(), substream_id);
let mut guard = OpenedSubstreamGuard(Some(self.inner.clone()), substream_id);
// We send `Open { substream_id }`, then flush, and only then produce the substream.
let future = {
@@ -232,17 +278,14 @@ where C: AsyncRead + AsyncWrite + 'static // TODO: 'static :-/
move |()| {
future::poll_fn(move || inner.lock().inner.poll_complete())
}
}).map({
let inner = self.inner.clone();
move |()| {
mem::forget(guard);
Some(Substream {
inner: inner.clone(),
num: substream_id,
current_data: Bytes::new(),
endpoint: Endpoint::Dialer,
})
}
}).map(move |()| {
debug!("Successfully opened outbound substream {}", substream_id);
Some(Substream {
inner: guard.0.take().unwrap(),
num: substream_id,
current_data: Bytes::new(),
endpoint: Endpoint::Dialer,
})
})
};
@@ -265,19 +308,20 @@ where C: AsyncRead + AsyncWrite
let mut inner = self.inner.lock();
if inner.opened_substreams.len() >= MAX_SUBSTREAMS {
debug!("Refused substream ; reached maximum number of substreams {}", MAX_SUBSTREAMS);
return Err(IoError::new(IoErrorKind::ConnectionRefused,
"exceeded maximum number of open substreams"));
}
let num = try_ready!(next_match(&mut inner, |elem| {
match elem {
codec::Elem::Open { substream_id } => Some(*substream_id), // TODO: check even/uneven?
codec::Elem::Open { substream_id } => Some(*substream_id),
_ => None,
}
}));
if let Some(num) = num {
inner.opened_substreams.insert(num);
debug!("Successfully opened inbound substream {}", num);
Ok(Async::Ready(Some(Substream {
inner: self.inner.clone(),
current_data: Bytes::new(),
@@ -306,13 +350,14 @@ where C: AsyncRead + AsyncWrite
{
fn read(&mut self, buf: &mut [u8]) -> Result<usize, IoError> {
loop {
// First transfer from `current_data`.
// First, transfer from `current_data`.
if self.current_data.len() != 0 {
let len = cmp::min(self.current_data.len(), buf.len());
buf[..len].copy_from_slice(&self.current_data.split_to(len));
return Ok(len);
}
// Try to find a packet of data in the buffer.
let mut inner = self.inner.lock();
let next_data_poll = next_match(&mut inner, |elem| {
match elem {
@@ -328,7 +373,15 @@ where C: AsyncRead + AsyncWrite
match next_data_poll {
Ok(Async::Ready(Some(data))) => self.current_data = data.freeze(),
Ok(Async::Ready(None)) => return Ok(0),
Ok(Async::NotReady) => return Err(IoErrorKind::WouldBlock.into()),
Ok(Async::NotReady) => {
// There was no data packet in the buffer for this substream; maybe it's
// because it has been closed.
if inner.opened_substreams.contains(&self.num) {
return Err(IoErrorKind::WouldBlock.into());
} else {
return Ok(0);
}
},
Err(err) => return Err(err),
}
}
@@ -372,7 +425,7 @@ impl<C> AsyncWrite for Substream<C>
where C: AsyncRead + AsyncWrite
{
fn shutdown(&mut self) -> Poll<(), IoError> {
let elem = codec::Elem::Close {
let elem = codec::Elem::Reset {
substream_id: self.num,
endpoint: self.endpoint,
};
@@ -386,7 +439,7 @@ impl<C> Drop for Substream<C>
where C: AsyncRead + AsyncWrite
{
fn drop(&mut self) {
let _ = self.shutdown();
clean_out_substream(&mut self.inner.lock(), self.num);
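// Besides trying to send the close message, purge any data buffered for this
// substream so it no longer lingers in the shared buffer.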
let _ = self.shutdown(); // TODO: this doesn't necessarily send the close message
self.inner.lock().buffer.retain(|elem| elem.substream_id() != self.num);
}
}