diff --git a/core/src/protocols_handler/node_handler.rs b/core/src/protocols_handler/node_handler.rs index 855c8c15..eadf9195 100644 --- a/core/src/protocols_handler/node_handler.rs +++ b/core/src/protocols_handler/node_handler.rs @@ -31,7 +31,7 @@ use crate::{ } }; use futures::prelude::*; -use std::{error, fmt, time::{Duration, Instant}}; +use std::{error, fmt, time::Duration}; use tokio_timer::{Delay, Timeout}; /// Prototype for a `NodeHandlerWrapper`. @@ -64,7 +64,7 @@ where negotiating_out: Vec::new(), queued_dial_upgrades: Vec::new(), unique_dial_upgrade_id: 0, - connection_shutdown: None, + shutdown: Shutdown::None, } } } @@ -85,7 +85,7 @@ where negotiating_out: Vec::new(), queued_dial_upgrades: Vec::new(), unique_dial_upgrade_id: 0, - connection_shutdown: None, + shutdown: Shutdown::None, } } } @@ -112,9 +112,26 @@ where queued_dial_upgrades: Vec<(u64, TProtoHandler::OutboundProtocol)>, /// Unique identifier assigned to each queued dial upgrade. unique_dial_upgrade_id: u64, - /// When a connection has been deemed useless, will contain `Some` with a `Delay` to when it - /// should be shut down. - connection_shutdown: Option, + /// The currently planned connection & handler shutdown. + shutdown: Shutdown, +} + +/// The options for a planned connection & handler shutdown. +/// +/// A shutdown is planned anew based on the the return value of +/// [`ProtocolsHandler::connection_keep_alive`] of the underlying handler +/// after every invocation of [`ProtocolsHandler::poll`]. +/// +/// A planned shutdown is always postponed for as long as there are ingoing +/// or outgoing substreams being negotiated, i.e. it is a graceful, "idle" +/// shutdown. +enum Shutdown { + /// No shutdown is planned. + None, + /// A shut down is planned as soon as possible. + Asap, + /// A shut down is planned for when a `Delay` has elapsed. + Later(Delay) } /// Error generated by the `NodeHandlerWrapper`. @@ -257,10 +274,12 @@ where // calls on `self.handler`. let poll_result = self.handler.poll()?; - self.connection_shutdown = match self.handler.connection_keep_alive() { - KeepAlive::Until(expiration) => Some(Delay::new(expiration)), - KeepAlive::Now => Some(Delay::new(Instant::now())), - KeepAlive::Forever => None, + // Ask the handler whether it wants the connection (and the handler itself) + // to be kept alive, which determines the planned shutdown, if any. + self.shutdown = match self.handler.connection_keep_alive() { + KeepAlive::Until(t) => Shutdown::Later(Delay::new(t)), + KeepAlive::Now => Shutdown::Asap, + KeepAlive::Forever => Shutdown::None }; match poll_result { @@ -282,20 +301,17 @@ where Async::NotReady => (), }; - // Check the `connection_shutdown`. - if let Some(mut connection_shutdown) = self.connection_shutdown.take() { - // If we're negotiating substreams, let's delay the closing. - if self.negotiating_in.is_empty() && self.negotiating_out.is_empty() { - match connection_shutdown.poll() { - Ok(Async::Ready(_)) | Err(_) => { - return Err(NodeHandlerWrapperError::UselessTimeout); - }, - Ok(Async::NotReady) => { - self.connection_shutdown = Some(connection_shutdown); - } + // Check if the connection (and handler) should be shut down. + // As long as we're still negotiating substreams, shutdown is always postponed. + if self.negotiating_in.is_empty() && self.negotiating_out.is_empty() { + match self.shutdown { + Shutdown::None => {}, + Shutdown::Asap => return Err(NodeHandlerWrapperError::UselessTimeout), + Shutdown::Later(ref mut delay) => match delay.poll() { + Ok(Async::Ready(_)) | Err(_) => + return Err(NodeHandlerWrapperError::UselessTimeout), + Ok(Async::NotReady) => {} } - } else { - self.connection_shutdown = Some(connection_shutdown); } } diff --git a/core/tests/raw_swarm_simult.rs b/core/tests/raw_swarm_simult.rs index 6c0fd2bd..d783f353 100644 --- a/core/tests/raw_swarm_simult.rs +++ b/core/tests/raw_swarm_simult.rs @@ -76,7 +76,7 @@ where } - fn connection_keep_alive(&self) -> KeepAlive { KeepAlive::Now } + fn connection_keep_alive(&self) -> KeepAlive { KeepAlive::Forever } fn poll(&mut self) -> Poll, Self::Error> { Ok(Async::NotReady) diff --git a/protocols/identify/src/identify.rs b/protocols/identify/src/identify.rs index 721efcbb..e7e60547 100644 --- a/protocols/identify/src/identify.rs +++ b/protocols/identify/src/identify.rs @@ -220,60 +220,60 @@ pub enum IdentifyEvent { #[cfg(test)] mod tests { use crate::{Identify, IdentifyEvent}; - use futures::prelude::*; - use libp2p_core::identity; + use futures::{future, prelude::*}; use libp2p_core::{ + identity, + PeerId, upgrade::{self, OutboundUpgradeExt, InboundUpgradeExt}, + muxing::StreamMuxer, Multiaddr, Swarm, Transport }; + use libp2p_tcp::TcpConfig; + use libp2p_secio::SecioConfig; + use libp2p_mplex::MplexConfig; use rand::Rng; - use std::io; + use std::{fmt, io}; + use tokio::runtime::current_thread; + + fn transport() -> (identity::PublicKey, impl Transport< + Output = (PeerId, impl StreamMuxer), + Listener = impl Send, + ListenerUpgrade = impl Send, + Dial = impl Send, + Error = impl fmt::Debug + > + Clone) { + let id_keys = identity::Keypair::generate_ed25519(); + let pubkey = id_keys.public(); + let transport = TcpConfig::new() + .nodelay(true) + .with_upgrade(SecioConfig::new(id_keys)) + .and_then(move |out, endpoint| { + let peer_id = out.remote_key.into_peer_id(); + let peer_id2 = peer_id.clone(); + let upgrade = MplexConfig::default() + .map_outbound(move |muxer| (peer_id, muxer)) + .map_inbound(move |muxer| (peer_id2, muxer)); + upgrade::apply(out.stream, upgrade, endpoint) + }); + (pubkey, transport) + } #[test] fn periodic_id_works() { - let node1_key = identity::Keypair::generate_ed25519(); - let node1_public_key = node1_key.public(); - let node2_key = identity::Keypair::generate_ed25519(); - let node2_public_key = node2_key.public(); - - let mut swarm1 = { - // TODO: make creating the transport more elegant ; literaly half of the code of the test - // is about creating the transport - let local_peer_id = node1_public_key.clone().into_peer_id(); - let transport = libp2p_tcp::TcpConfig::new() - .with_upgrade(libp2p_secio::SecioConfig::new(node1_key)) - .and_then(move |out, endpoint| { - let peer_id = out.remote_key.into_peer_id(); - let peer_id2 = peer_id.clone(); - let upgrade = libp2p_mplex::MplexConfig::default() - .map_outbound(move |muxer| (peer_id, muxer)) - .map_inbound(move |muxer| (peer_id2, muxer)); - upgrade::apply(out.stream, upgrade, endpoint) - }) - .map_err(|_| -> io::Error { panic!() }); - - Swarm::new(transport, Identify::new("a".to_string(), "b".to_string(), node1_public_key.clone()), local_peer_id) + let (mut swarm1, pubkey1) = { + let (pubkey, transport) = transport(); + let protocol = Identify::new("a".to_string(), "b".to_string(), pubkey.clone()); + let swarm = Swarm::new(transport, protocol, pubkey.clone().into_peer_id()); + (swarm, pubkey) }; - let mut swarm2 = { - // TODO: make creating the transport more elegant ; literaly half of the code of the test - // is about creating the transport - let local_peer_id = node2_public_key.clone().into(); - let transport = libp2p_tcp::TcpConfig::new() - .with_upgrade(libp2p_secio::SecioConfig::new(node2_key)) - .and_then(move |out, endpoint| { - let peer_id = out.remote_key.into_peer_id(); - let peer_id2 = peer_id.clone(); - let upgrade = libp2p_mplex::MplexConfig::default() - .map_outbound(move |muxer| (peer_id, muxer)) - .map_inbound(move |muxer| (peer_id2, muxer)); - upgrade::apply(out.stream, upgrade, endpoint) - }) - .map_err(|_| -> io::Error { panic!() }); - - Swarm::new(transport, Identify::new("c".to_string(), "d".to_string(), node2_public_key.clone()), local_peer_id) + let (mut swarm2, pubkey2) = { + let (pubkey, transport) = transport(); + let protocol = Identify::new("c".to_string(), "d".to_string(), pubkey.clone()); + let swarm = Swarm::new(transport, protocol, pubkey.clone().into_peer_id()); + (swarm, pubkey) }; let addr: Multiaddr = { @@ -282,51 +282,45 @@ mod tests { }; Swarm::listen_on(&mut swarm1, addr.clone()).unwrap(); - Swarm::dial_addr(&mut swarm2, addr).unwrap(); + Swarm::dial_addr(&mut swarm2, addr.clone()).unwrap(); - let mut swarm1_good = false; - let mut swarm2_good = false; - - tokio::runtime::current_thread::Runtime::new() - .unwrap() - .block_on(futures::future::poll_fn(move || -> Result<_, io::Error> { + // nb. Either swarm may receive the `Identified` event first, upon which + // it will permit the connection to be closed, as defined by + // `PeriodicIdHandler::connection_keep_alive`. Hence the test succeeds if + // either `Identified` event arrives correctly. + current_thread::Runtime::new().unwrap().block_on( + future::poll_fn(move || -> Result<_, io::Error> { loop { - let mut swarm1_not_ready = false; match swarm1.poll().unwrap() { Async::Ready(Some(IdentifyEvent::Identified { info, .. })) => { - assert_eq!(info.public_key, node2_public_key); + assert_eq!(info.public_key, pubkey2); assert_eq!(info.protocol_version, "c"); assert_eq!(info.agent_version, "d"); assert!(!info.protocols.is_empty()); assert!(info.listen_addrs.is_empty()); - swarm1_good = true; + return Ok(Async::Ready(())) }, Async::Ready(Some(IdentifyEvent::SendBack { result: Ok(()), .. })) => (), - Async::Ready(_) => panic!(), - Async::NotReady => swarm1_not_ready = true, + Async::Ready(e) => panic!("{:?}", e), + Async::NotReady => {} } match swarm2.poll().unwrap() { Async::Ready(Some(IdentifyEvent::Identified { info, .. })) => { - assert_eq!(info.public_key, node1_public_key); + assert_eq!(info.public_key, pubkey1); assert_eq!(info.protocol_version, "a"); assert_eq!(info.agent_version, "b"); assert!(!info.protocols.is_empty()); assert_eq!(info.listen_addrs.len(), 1); - swarm2_good = true; + return Ok(Async::Ready(())) }, Async::Ready(Some(IdentifyEvent::SendBack { result: Ok(()), .. })) => (), - Async::Ready(_) => panic!(), - Async::NotReady if swarm1_not_ready => break, - Async::NotReady => () + Async::Ready(e) => panic!("{:?}", e), + Async::NotReady => break } } - if swarm1_good && swarm2_good { - Ok(Async::Ready(())) - } else { - Ok(Async::NotReady) - } + Ok(Async::NotReady) })) .unwrap(); }