From 7d1a0238dea8788a80371cb8bc2300a8c23088e5 Mon Sep 17 00:00:00 2001 From: Toralf Wittner Date: Mon, 15 Oct 2018 10:42:11 +0200 Subject: [PATCH] Add shutdown functionality to `NodeStream`. (#560) Add shutdown functionality to `NodeStream`. Add `NodeStream::shutdown` to allow triggering the shutdown process, and `NodeStream::poll_shutdown` as the internal way to drive any potential shutdown to completion. --- core/src/muxing.rs | 1 + core/src/nodes/handled_node.rs | 116 +++++++++++----------- core/src/nodes/node.rs | 175 ++++++++++++++++++++++++--------- 3 files changed, 186 insertions(+), 106 deletions(-) diff --git a/core/src/muxing.rs b/core/src/muxing.rs index 46dc28b8..f07088bc 100644 --- a/core/src/muxing.rs +++ b/core/src/muxing.rs @@ -28,6 +28,7 @@ use std::sync::atomic::{AtomicUsize, Ordering}; use tokio_io::{AsyncRead, AsyncWrite}; /// Ways to shutdown a substream or stream muxer. +#[derive(Clone, Copy, Debug, PartialEq, Eq)] pub enum Shutdown { /// Shutdown inbound direction. Inbound, diff --git a/core/src/nodes/handled_node.rs b/core/src/nodes/handled_node.rs index 65c5bac7..456d7827 100644 --- a/core/src/nodes/handled_node.rs +++ b/core/src/nodes/handled_node.rs @@ -20,7 +20,7 @@ use muxing::StreamMuxer; use nodes::node::{NodeEvent, NodeStream, Substream}; -use futures::prelude::*; +use futures::{prelude::*, stream::Fuse}; use std::io::Error as IoError; use Multiaddr; @@ -124,10 +124,12 @@ where TMuxer: StreamMuxer, THandler: NodeHandler>, { - /// Node that handles the muxing. Can be `None` if the handled node is shutting down. - node: Option>, + /// Node that handles the muxing. + node: Fuse>, /// Handler that processes substreams. handler: THandler, + // True, if the node is shutting down. + is_shutting_down: bool } impl HandledNode @@ -140,8 +142,9 @@ where #[inline] pub fn new(muxer: TMuxer, multiaddr_future: TAddrFut, handler: THandler) -> Self { HandledNode { - node: Some(NodeStream::new(muxer, multiaddr_future)), + node: NodeStream::new(muxer, multiaddr_future).fuse(), handler, + is_shutting_down: false } } @@ -151,26 +154,26 @@ where self.handler.inject_event(event); } - /// Returns true if the inbound channel of the muxer is closed. + /// Returns true if the inbound channel of the muxer is open. /// - /// If `true` is returned, then no more inbound substream will be received. + /// If `true` is returned, more inbound substream will be received. #[inline] - pub fn is_inbound_closed(&self) -> bool { - self.node.as_ref().map(|n| n.is_inbound_closed()).unwrap_or(true) + pub fn is_inbound_open(&self) -> bool { + self.node.get_ref().is_inbound_open() } - /// Returns true if the outbound channel of the muxer is closed. + /// Returns true if the outbound channel of the muxer is open. /// - /// If `true` is returned, then no more outbound substream will be opened. + /// If `true` is returned, more outbound substream will be opened. #[inline] - pub fn is_outbound_closed(&self) -> bool { - self.node.as_ref().map(|n| n.is_outbound_closed()).unwrap_or(true) + pub fn is_outbound_open(&self) -> bool { + self.node.get_ref().is_outbound_open() } /// Returns true if the handled node is in the process of shutting down. #[inline] pub fn is_shutting_down(&self) -> bool { - self.node.is_none() + self.is_shutting_down } /// Indicates to the handled node that it should shut down. After calling this method, the @@ -178,13 +181,14 @@ where /// /// After this method returns, `is_shutting_down()` should return true. pub fn shutdown(&mut self) { - if let Some(node) = self.node.take() { - for user_data in node.close() { - self.handler.inject_outbound_closed(user_data); - } + self.node.get_mut().shutdown_all(); + self.is_shutting_down = true; + + for user_data in self.node.get_mut().cancel_outgoing() { + self.handler.inject_outbound_closed(user_data); } - self.handler.shutdown(); + self.handler.shutdown() } } @@ -201,60 +205,54 @@ where loop { let mut node_not_ready = false; - match self.node.as_mut().map(|n| n.poll()) { - Some(Ok(Async::NotReady)) | None => {}, - Some(Ok(Async::Ready(Some(NodeEvent::InboundSubstream { substream })))) => { - self.handler.inject_substream(substream, NodeHandlerEndpoint::Listener); - }, - Some(Ok(Async::Ready(Some(NodeEvent::OutboundSubstream { user_data, substream })))) => { + match self.node.poll()? { + Async::NotReady => (), + Async::Ready(Some(NodeEvent::InboundSubstream { substream })) => { + self.handler.inject_substream(substream, NodeHandlerEndpoint::Listener) + } + Async::Ready(Some(NodeEvent::OutboundSubstream { user_data, substream })) => { let endpoint = NodeHandlerEndpoint::Dialer(user_data); - self.handler.inject_substream(substream, endpoint); - }, - Some(Ok(Async::Ready(None))) => { + self.handler.inject_substream(substream, endpoint) + } + Async::Ready(None) => { node_not_ready = true; - self.node = None; - self.handler.shutdown(); - }, - Some(Ok(Async::Ready(Some(NodeEvent::Multiaddr(result))))) => { - self.handler.inject_multiaddr(result); - }, - Some(Ok(Async::Ready(Some(NodeEvent::OutboundClosed { user_data })))) => { - self.handler.inject_outbound_closed(user_data); - }, - Some(Ok(Async::Ready(Some(NodeEvent::InboundClosed)))) => { - self.handler.inject_inbound_closed(); - }, - Some(Err(err)) => { - self.node = None; - return Err(err); - }, + if !self.is_shutting_down { + self.handler.shutdown() + } + } + Async::Ready(Some(NodeEvent::Multiaddr(result))) => { + self.handler.inject_multiaddr(result) + } + Async::Ready(Some(NodeEvent::OutboundClosed { user_data })) => { + self.handler.inject_outbound_closed(user_data) + } + Async::Ready(Some(NodeEvent::InboundClosed)) => { + self.handler.inject_inbound_closed() + } } - match self.handler.poll() { - Ok(Async::NotReady) => { + match self.handler.poll()? { + Async::NotReady => { if node_not_ready { - break; + break } - }, - Ok(Async::Ready(Some(NodeHandlerEvent::OutboundSubstreamRequest(user_data)))) => { - if let Some(node) = self.node.as_mut() { - match node.open_substream(user_data) { + } + Async::Ready(Some(NodeHandlerEvent::OutboundSubstreamRequest(user_data))) => { + if self.node.get_ref().is_outbound_open() { + match self.node.get_mut().open_substream(user_data) { Ok(()) => (), Err(user_data) => self.handler.inject_outbound_closed(user_data), } } else { self.handler.inject_outbound_closed(user_data); } - }, - Ok(Async::Ready(Some(NodeHandlerEvent::Custom(event)))) => { + } + Async::Ready(Some(NodeHandlerEvent::Custom(event))) => { return Ok(Async::Ready(Some(event))); - }, - Ok(Async::Ready(None)) => { - return Ok(Async::Ready(None)); - }, - Err(err) => { - return Err(err); - }, + } + Async::Ready(None) => { + return Ok(Async::Ready(None)) + } } } diff --git a/core/src/nodes/node.rs b/core/src/nodes/node.rs index b22384ab..10c2098b 100644 --- a/core/src/nodes/node.rs +++ b/core/src/nodes/node.rs @@ -59,10 +59,10 @@ where { /// The muxer used to manage substreams. muxer: Arc, - /// If true, the inbound side of the muxer has closed earlier and should no longer be polled. - inbound_finished: bool, - /// If true, the outbound side of the muxer has closed earlier. - outbound_finished: bool, + /// Tracks the state of the muxers inbound direction. + inbound_state: StreamState, + /// Tracks the state of the muxers outbound direction. + outbound_state: StreamState, /// Address of the node ; can be empty if the address hasn't been resolved yet. address: Addr, /// List of substreams we are currently opening. @@ -83,6 +83,19 @@ enum Addr { /// A successfully opened substream. pub type Substream = muxing::SubstreamRef>; +// Track state of stream muxer per direction. +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +enum StreamState { + // direction is open + Open, + // direction is shutting down + Shutdown, + // direction has shutdown and is flushing + Flush, + // direction is closed + Closed +} + /// Event that can happen on the `NodeStream`. #[derive(Debug)] pub enum NodeEvent @@ -134,8 +147,8 @@ where pub fn new(muxer: TMuxer, multiaddr_future: TAddrFut) -> Self { NodeStream { muxer: Arc::new(muxer), - inbound_finished: false, - outbound_finished: false, + inbound_state: StreamState::Open, + outbound_state: StreamState::Open, address: Addr::Future(multiaddr_future), outbound_substreams: SmallVec::new(), } @@ -161,7 +174,7 @@ where /// `OutboundSubstream` event or an `OutboundClosed` event containing the user data that has /// been passed to this method. pub fn open_substream(&mut self, user_data: TUserData) -> Result<(), TUserData> { - if self.outbound_finished { + if self.outbound_state != StreamState::Open { return Err(user_data); } @@ -171,25 +184,30 @@ where Ok(()) } - /// Returns true if the inbound channel of the muxer is closed. + /// Returns true if the inbound channel of the muxer is open. /// - /// If `true` is returned, then no more inbound substream will be produced. + /// If `true` is returned, more inbound substream will be produced. #[inline] - pub fn is_inbound_closed(&self) -> bool { - self.inbound_finished + pub fn is_inbound_open(&self) -> bool { + self.inbound_state == StreamState::Open } - /// Returns true if the outbound channel of the muxer is closed. + /// Returns true if the outbound channel of the muxer is open. /// - /// If `true` is returned, then no more outbound substream can be opened. Calling + /// If `true` is returned, more outbound substream can be opened. Otherwise, calling /// `open_substream` will return an `Err`. #[inline] - pub fn is_outbound_closed(&self) -> bool { - self.outbound_finished + pub fn is_outbound_open(&self) -> bool { + self.outbound_state == StreamState::Open } /// Destroys the node stream and returns all the pending outbound substreams. pub fn close(mut self) -> Vec { + self.cancel_outgoing() + } + + /// Destroys all outbound streams and returns the corresponding user data. + pub fn cancel_outgoing(&mut self) -> Vec { let mut out = Vec::with_capacity(self.outbound_substreams.len()); for (user_data, outbound) in self.outbound_substreams.drain() { out.push(user_data); @@ -197,6 +215,75 @@ where } out } + + /// Trigger node shutdown. + /// + /// After this, `NodeStream::poll` will eventually produce `None`, when both endpoints are + /// closed. + pub fn shutdown_all(&mut self) { + if self.inbound_state == StreamState::Open { + self.inbound_state = StreamState::Shutdown + } + if self.outbound_state == StreamState::Open { + self.outbound_state = StreamState::Shutdown + } + } + + // If in progress, drive this node's stream muxer shutdown to completion. + fn poll_shutdown(&mut self) -> Poll<(), IoError> { + use self::StreamState::*; + loop { + match (self.inbound_state, self.outbound_state) { + (Open, Open) | (Open, Closed) | (Closed, Open) | (Closed, Closed) => { + return Ok(Async::Ready(())) + } + (Shutdown, Shutdown) => { + if let Async::Ready(()) = self.muxer.shutdown(muxing::Shutdown::All)? { + self.inbound_state = StreamState::Flush; + self.outbound_state = StreamState::Flush; + continue + } + return Ok(Async::NotReady) + } + (Shutdown, _) => { + if let Async::Ready(()) = self.muxer.shutdown(muxing::Shutdown::Inbound)? { + self.inbound_state = StreamState::Flush; + continue + } + return Ok(Async::NotReady) + } + (_, Shutdown) => { + if let Async::Ready(()) = self.muxer.shutdown(muxing::Shutdown::Outbound)? { + self.outbound_state = StreamState::Flush; + continue + } + return Ok(Async::NotReady) + } + (Flush, Open) => { + if let Async::Ready(()) = self.muxer.flush_all()? { + self.inbound_state = StreamState::Closed; + continue + } + return Ok(Async::NotReady) + } + (Open, Flush) => { + if let Async::Ready(()) = self.muxer.flush_all()? { + self.outbound_state = StreamState::Closed; + continue + } + return Ok(Async::NotReady) + } + (Flush, Flush) | (Flush, Closed) | (Closed, Flush) => { + if let Async::Ready(()) = self.muxer.flush_all()? { + self.inbound_state = StreamState::Closed; + self.outbound_state = StreamState::Closed; + continue + } + return Ok(Async::NotReady) + } + } + } + } } impl Stream for NodeStream @@ -208,21 +295,25 @@ where type Error = IoError; fn poll(&mut self) -> Poll, Self::Error> { + // Drive the shutdown process, if any. + if self.poll_shutdown()?.is_not_ready() { + return Ok(Async::NotReady) + } + // Polling inbound substream. - if !self.inbound_finished { - match self.muxer.poll_inbound() { - Ok(Async::Ready(Some(substream))) => { + if self.inbound_state == StreamState::Open { + match self.muxer.poll_inbound()? { + Async::Ready(Some(substream)) => { let substream = muxing::substream_from_ref(self.muxer.clone(), substream); return Ok(Async::Ready(Some(NodeEvent::InboundSubstream { substream, }))); } - Ok(Async::Ready(None)) => { - self.inbound_finished = true; + Async::Ready(None) => { + self.inbound_state = StreamState::Closed; return Ok(Async::Ready(Some(NodeEvent::InboundClosed))); } - Ok(Async::NotReady) => {} - Err(err) => return Err(err), + Async::NotReady => {} } } @@ -240,7 +331,7 @@ where }))); } Ok(Async::Ready(None)) => { - self.outbound_finished = true; + self.outbound_state = StreamState::Closed; self.muxer.destroy_outbound(outbound); return Ok(Async::Ready(Some(NodeEvent::OutboundClosed { user_data }))); } @@ -275,8 +366,11 @@ where } // Closing the node if there's no way we can do anything more. - if self.inbound_finished && self.outbound_finished && self.outbound_substreams.is_empty() { - return Ok(Async::Ready(None)); + if self.inbound_state == StreamState::Closed + && self.outbound_state == StreamState::Closed + && self.outbound_substreams.is_empty() + { + return Ok(Async::Ready(None)) } // Nothing happened. Register our task to be notified and return. @@ -292,8 +386,8 @@ where fn fmt(&self, f: &mut fmt::Formatter) -> Result<(), fmt::Error> { f.debug_struct("NodeStream") .field("address", &self.multiaddr()) - .field("inbound_finished", &self.inbound_finished) - .field("outbound_finished", &self.outbound_finished) + .field("inbound_state", &self.inbound_state) + .field("outbound_state", &self.outbound_state) .field("outbound_substreams", &self.outbound_substreams.len()) .finish() } @@ -310,19 +404,6 @@ where for (_, outbound) in self.outbound_substreams.drain() { self.muxer.destroy_outbound(outbound); } - // TODO: Maybe the shutdown logic should not be part of the destructor? - match (self.inbound_finished, self.outbound_finished) { - (true, true) => {} - (true, false) => { - let _ = self.muxer.shutdown(muxing::Shutdown::Outbound); - } - (false, true) => { - let _ = self.muxer.shutdown(muxing::Shutdown::Inbound); - } - (false, false) => { - let _ = self.muxer.shutdown(muxing::Shutdown::All); - } - } } } @@ -402,7 +483,7 @@ mod node_stream { }) }); - // Opening a second substream fails because `outbound_finished` is now true + // Opening a second substream fails because `outbound_state` is no longer open. assert_matches!(ns.open_substream(vec![22]), Err(user_data) => { assert_eq!(user_data, vec![22]); }); @@ -411,8 +492,8 @@ mod node_stream { #[test] fn query_inbound_outbound_state() { let ns = build_node_stream(); - assert_eq!(ns.is_inbound_closed(), false); - assert_eq!(ns.is_outbound_closed(), false); + assert!(ns.is_inbound_open()); + assert!(ns.is_outbound_open()); } #[test] @@ -426,7 +507,7 @@ mod node_stream { assert_matches!(node_event, NodeEvent::InboundClosed) }); - assert_eq!(ns.is_inbound_closed(), true); + assert!(!ns.is_inbound_open()); } #[test] @@ -436,7 +517,7 @@ mod node_stream { muxer.set_outbound_connection_state(DummyConnectionState::Closed); let mut ns = NodeStream::<_, _, Vec>::new(muxer, addr); - assert_eq!(ns.is_outbound_closed(), false); + assert!(ns.is_outbound_open()); ns.open_substream(vec![1]).unwrap(); let poll_result = ns.poll(); @@ -447,7 +528,7 @@ mod node_stream { }) }); - assert_eq!(ns.is_outbound_closed(), true, "outbound connection should be closed after polling"); + assert!(!ns.is_outbound_open(), "outbound connection should be closed after polling"); } #[test] @@ -552,7 +633,7 @@ mod node_stream { ns.open_substream(vec![1]).unwrap(); ns.poll().unwrap(); // poll past inbound ns.poll().unwrap(); // poll outbound - assert_eq!(ns.is_outbound_closed(), false); + assert!(ns.is_outbound_open()); assert!(format!("{:?}", ns).contains("outbound_substreams: 1")); }