From 75499489456c7f3046cf354a03f26d30c8c7138f Mon Sep 17 00:00:00 2001 From: mattrutherford <44339188+mattrutherford@users.noreply.github.com> Date: Thu, 28 Mar 2019 22:34:53 +0000 Subject: [PATCH] Use bounded channels in transport (#987) * Implement DialFuture * Update with recommended changes to buffer size, `expect()` and `close()` --- core/src/transport/memory.rs | 69 +++++++++++++++++++++++++----------- 1 file changed, 49 insertions(+), 20 deletions(-) diff --git a/core/src/transport/memory.rs b/core/src/transport/memory.rs index b8ae0e4e..7f6c7819 100644 --- a/core/src/transport/memory.rs +++ b/core/src/transport/memory.rs @@ -29,19 +29,50 @@ use rw_stream_sink::RwStreamSink; use std::{collections::hash_map::Entry, error, fmt, io, num::NonZeroU64}; lazy_static! { - static ref HUB: Mutex>>> = Mutex::new(FnvHashMap::default()); + static ref HUB: Mutex>>> = Mutex::new(FnvHashMap::default()); } /// Transport that supports `/memory/N` multiaddresses. #[derive(Debug, Copy, Clone, Default)] pub struct MemoryTransport; +/// Connection to a `MemoryTransport` currently being opened. +pub struct DialFuture { + sender: mpsc::Sender>, + channel_to_send: Option>, + channel_to_return: Option>, +} + +impl Future for DialFuture { + type Item = Channel; + type Error = MemoryTransportError; + + fn poll(&mut self) -> Poll { + if let Some(c) = self.channel_to_send.take() { + match self.sender.start_send(c) { + Err(_) => return Err(MemoryTransportError::Unreachable), + Ok(AsyncSink::NotReady(t)) => { + self.channel_to_send = Some(t); + return Ok(Async::NotReady) + }, + _ => (), + } + } + match self.sender.close() { + Err(_) => Err(MemoryTransportError::Unreachable), + Ok(Async::NotReady) => Ok(Async::NotReady), + Ok(Async::Ready(_)) => Ok(Async::Ready(self.channel_to_return.take() + .expect("Future should not be polled again once complete"))), + } + } +} + impl Transport for MemoryTransport { type Output = Channel; type Error = MemoryTransportError; type Listener = Listener; type ListenerUpgrade = FutureResult; - type Dial = FutureResult; + type Dial = DialFuture; fn listen_on(self, addr: Multiaddr) -> Result<(Self::Listener, Multiaddr), TransportError> { let port = if let Ok(port) = parse_memory_addr(&addr) { @@ -68,7 +99,7 @@ impl Transport for MemoryTransport { let actual_addr = Protocol::Memory(port.get()).into(); - let (tx, rx) = mpsc::unbounded(); + let (tx, rx) = mpsc::channel(2); match hub.entry(port) { Entry::Occupied(_) => return Err(TransportError::Other(MemoryTransportError::Unreachable)), Entry::Vacant(e) => e.insert(tx), @@ -82,7 +113,7 @@ impl Transport for MemoryTransport { Ok((listener, actual_addr)) } - fn dial(self, addr: Multiaddr) -> Result> { + fn dial(self, addr: Multiaddr) -> Result> { let port = if let Ok(port) = parse_memory_addr(&addr) { if let Some(port) = NonZeroU64::new(port) { port @@ -94,20 +125,18 @@ impl Transport for MemoryTransport { }; let hub = HUB.lock(); - let chan = if let Some(tx) = hub.get(&port) { - let (a_tx, a_rx) = mpsc::unbounded(); - let (b_tx, b_rx) = mpsc::unbounded(); - let a = RwStreamSink::new(Chan { incoming: a_rx, outgoing: b_tx }); - let b = RwStreamSink::new(Chan { incoming: b_rx, outgoing: a_tx }); - if tx.unbounded_send(b).is_err() { - return Err(TransportError::Other(MemoryTransportError::Unreachable)); - } - a - } else { - return Err(TransportError::Other(MemoryTransportError::Unreachable)); - }; + if let Some(sender) = hub.get(&port) { + let (a_tx, a_rx) = mpsc::channel(4096); + let (b_tx, b_rx) = mpsc::channel(4096); + Ok(DialFuture { + sender: sender.clone(), + channel_to_send: Some(RwStreamSink::new(Chan { incoming: a_rx, outgoing: b_tx })), + channel_to_return: Some(RwStreamSink::new(Chan { incoming: b_rx, outgoing: a_tx })), - Ok(future::ok(chan)) + }) + } else { + Err(TransportError::Other(MemoryTransportError::Unreachable)) + } } fn nat_traversal(&self, server: &Multiaddr, observed: &Multiaddr) -> Option { @@ -145,7 +174,7 @@ pub struct Listener { /// Port we're listening on. port: NonZeroU64, /// Receives incoming connections. - receiver: mpsc::UnboundedReceiver>, + receiver: mpsc::Receiver>, } impl Stream for Listener { @@ -197,8 +226,8 @@ pub type Channel = RwStreamSink>; /// /// Implements `Sink` and `Stream`. pub struct Chan { - incoming: mpsc::UnboundedReceiver, - outgoing: mpsc::UnboundedSender, + incoming: mpsc::Receiver, + outgoing: mpsc::Sender, } impl Stream for Chan {