diff --git a/core/src/nodes/raw_swarm.rs b/core/src/nodes/raw_swarm.rs index 869d755a..4325974a 100644 --- a/core/src/nodes/raw_swarm.rs +++ b/core/src/nodes/raw_swarm.rs @@ -66,6 +66,9 @@ where /// The reach attempts of the swarm. /// This needs to be a separate struct in order to handle multiple mutable borrows issues. reach_attempts: ReachAttempts, + + /// Max numer of incoming connections. + incoming_limit: Option, } #[derive(Debug)] @@ -637,6 +640,25 @@ where other_reach_attempts: Vec::new(), connected_points: Default::default(), }, + incoming_limit: None, + } + } + + /// Creates a new node event stream with incoming connections limit. + #[inline] + pub fn new_with_incoming_limit(transport: TTrans, + local_peer_id: PeerId, incoming_limit: Option) -> Self + { + RawSwarm { + incoming_limit, + listeners: ListenersStream::new(transport), + active_nodes: CollectionStream::new(), + reach_attempts: ReachAttempts { + local_peer_id, + out_reach_attempts: Default::default(), + other_reach_attempts: Vec::new(), + connected_points: Default::default(), + }, } } @@ -658,6 +680,12 @@ where self.listeners.listeners() } + /// Returns limit on incoming connections. + #[inline] + pub fn incoming_limit(&self) -> Option { + self.incoming_limit + } + /// Call this function in order to know which address remotes should dial in order to access /// your local node. /// @@ -863,26 +891,39 @@ where ::OutboundOpenInfo: Send + 'static, // TODO: shouldn't be necessary THandlerErr: error::Error + Send + 'static, { - // Start by polling the listeners for events. - match self.listeners.poll() { - Async::NotReady => (), - Async::Ready(ListenersEvent::Incoming { upgrade, listen_addr, send_back_addr }) => { - let event = IncomingConnectionEvent { - upgrade, - local_peer_id: self.reach_attempts.local_peer_id.clone(), - listen_addr, - send_back_addr, - active_nodes: &mut self.active_nodes, - other_reach_attempts: &mut self.reach_attempts.other_reach_attempts, - }; - return Async::Ready(RawSwarmEvent::IncomingConnection(event)); - } - Async::Ready(ListenersEvent::Closed { listen_addr, listener, result }) => { - return Async::Ready(RawSwarmEvent::ListenerClosed { - listen_addr, - listener, - result, - }); + // Start by polling the listeners for events, but only + // if numer of incoming connection does not exceed the limit. + match self.incoming_limit { + Some(x) if self.incoming_negotiated().count() >= (x as usize) + => (), + _ => { + match self.listeners.poll() { + Async::NotReady => (), + Async::Ready(ListenersEvent::Incoming { + upgrade, listen_addr, send_back_addr }) => + { + let event = IncomingConnectionEvent { + upgrade, + local_peer_id: + self.reach_attempts.local_peer_id.clone(), + listen_addr, + send_back_addr, + active_nodes: &mut self.active_nodes, + other_reach_attempts: &mut self.reach_attempts.other_reach_attempts, + }; + return Async::Ready(RawSwarmEvent::IncomingConnection(event)); + }, + Async::Ready(ListenersEvent::Closed { + listen_addr, listener, result }) => + { + return Async::Ready(RawSwarmEvent::ListenerClosed { + listen_addr, + listener, + result, + }); + } + + } } } diff --git a/core/src/nodes/raw_swarm/tests.rs b/core/src/nodes/raw_swarm/tests.rs index 49c0e425..992879f2 100644 --- a/core/src/nodes/raw_swarm/tests.rs +++ b/core/src/nodes/raw_swarm/tests.rs @@ -454,3 +454,48 @@ fn local_prio_equivalence_relation() { assert_ne!(has_dial_prio(&a, &b), has_dial_prio(&b, &a)); } } + +#[test] +fn limit_incoming_connections() { + let mut transport = DummyTransport::new(); + let peer_id = PeerId::random(); + let muxer = DummyMuxer::new(); + let limit = 1; + transport.set_initial_listener_state(ListenerState::Ok(Async::Ready( + Some((peer_id, muxer))))); + let mut swarm = RawSwarm::<_, _, _, Handler, _>::new_with_incoming_limit( + transport, PeerId::random(), Some(limit)); + assert_eq!(swarm.incoming_limit(), Some(limit)); + swarm.listen_on("/memory".parse().unwrap()).unwrap(); + assert_eq!(swarm.incoming_negotiated().count(), 0); + + let swarm = Arc::new(Mutex::new(swarm)); + let mut rt = Runtime::new().unwrap(); + for i in 1..10 { + let swarm_fut = swarm.clone(); + let fut = future::poll_fn(move || -> Poll<_, ()> { + let mut swarm_fut = swarm_fut.lock(); + if i <= limit { + assert_matches!(swarm_fut.poll(), + Async::Ready(RawSwarmEvent::IncomingConnection(incoming)) => { + incoming.accept(Handler::default()); + }); + } else { + match swarm_fut.poll() { + Async::NotReady => (), + Async::Ready(x) => { + match x { + RawSwarmEvent::IncomingConnection(_) => (), + RawSwarmEvent::Connected { .. } => (), + _ => { panic!("Not expected event") }, + } + }, + } + } + Ok(Async::Ready(())) + }); + rt.block_on(fut).expect("tokio works"); + let swarm = swarm.lock(); + assert!(swarm.incoming_negotiated().count() <= (limit as usize)); + } +} diff --git a/core/src/swarm.rs b/core/src/swarm.rs index e073fdb2..857dc279 100644 --- a/core/src/swarm.rs +++ b/core/src/swarm.rs @@ -492,3 +492,170 @@ pub enum NetworkBehaviourAction { address: Multiaddr, }, } + +pub struct SwarmBuilder +where TTransport: Transport, + TBehaviour: NetworkBehaviour +{ + incoming_limit: Option, + topology: TTopology, + transport: TTransport, + behaviour: TBehaviour, +} + +impl SwarmBuilder +where TBehaviour: NetworkBehaviour, + TMuxer: StreamMuxer + Send + Sync + 'static, + ::OutboundSubstream: Send + 'static, + ::Substream: Send + 'static, + TTransport: Transport + Clone, + TTransport::Error: Send + 'static, + TTransport::Listener: Send + 'static, + TTransport::ListenerUpgrade: Send + 'static, + TTransport::Dial: Send + 'static, + TBehaviour::ProtocolsHandler: Send + 'static, + ::Handler: ProtocolsHandler> + Send + 'static, + <::Handler as ProtocolsHandler>::InEvent: Send + 'static, + <::Handler as ProtocolsHandler>::OutEvent: Send + 'static, + <::Handler as ProtocolsHandler>::Error: Send + 'static, + <::Handler as ProtocolsHandler>::OutboundOpenInfo: Send + 'static, // TODO: shouldn't be necessary + <::Handler as ProtocolsHandler>::InboundProtocol: InboundUpgrade> + Send + 'static, + <<::Handler as ProtocolsHandler>::InboundProtocol as UpgradeInfo>::Info: Send + 'static, + <<::Handler as ProtocolsHandler>::InboundProtocol as UpgradeInfo>::InfoIter: Send + 'static, + <<<::Handler as ProtocolsHandler>::InboundProtocol as UpgradeInfo>::InfoIter as IntoIterator>::IntoIter: Send + 'static, + <<::Handler as ProtocolsHandler>::InboundProtocol as InboundUpgrade>>::Error: fmt::Debug + Send + 'static, + <<::Handler as ProtocolsHandler>::InboundProtocol as InboundUpgrade>>::Future: Send + 'static, + <::Handler as ProtocolsHandler>::OutboundProtocol: OutboundUpgrade> + Send + 'static, + <<::Handler as ProtocolsHandler>::OutboundProtocol as UpgradeInfo>::Info: Send + 'static, + <<::Handler as ProtocolsHandler>::OutboundProtocol as UpgradeInfo>::InfoIter: Send + 'static, + <<<::Handler as ProtocolsHandler>::OutboundProtocol as UpgradeInfo>::InfoIter as IntoIterator>::IntoIter: Send + 'static, + <<::Handler as ProtocolsHandler>::OutboundProtocol as OutboundUpgrade>>::Future: Send + 'static, + <<::Handler as ProtocolsHandler>::OutboundProtocol as OutboundUpgrade>>::Error: fmt::Debug + Send + 'static, + ::Handler> as NodeHandler>::OutboundOpenInfo: Send + 'static, // TODO: shouldn't be necessary + TTopology: Topology, + +{ + pub fn new(transport: TTransport, behaviour: TBehaviour, + topology:TTopology) -> Self { + SwarmBuilder { + incoming_limit: None, + transport: transport, + topology: topology, + behaviour: behaviour, + } + } + + pub fn incoming_limit(mut self, incoming_limit: Option) -> Self + { + self.incoming_limit = incoming_limit; + self + } + + pub fn build(mut self) -> + Swarm + { + let supported_protocols = self.behaviour + .new_handler() + .into_handler(self.topology.local_peer_id()) + .listen_protocol() + .protocol_info() + .into_iter() + .map(|info| info.protocol_name().to_vec()) + .collect(); + let raw_swarm = RawSwarm::new_with_incoming_limit(self.transport, + self.topology.local_peer_id().clone(), + self.incoming_limit); + Swarm { + raw_swarm, + behaviour: self.behaviour, + topology: self.topology, + supported_protocols, + listened_addrs: SmallVec::new(), + } + } +} + +#[cfg(test)] +mod tests { + + use crate::nodes::raw_swarm::RawSwarm; + use crate::peer_id::PeerId; + use crate::protocols_handler::{DummyProtocolsHandler, ProtocolsHandler}; + use crate::public_key::PublicKey; + use crate::tests::dummy_transport::DummyTransport; + use crate::topology::MemoryTopology; + use futures::prelude::*; + use rand::random; + use smallvec::SmallVec; + use std::marker::PhantomData; + use super::{ConnectedPoint, NetworkBehaviour, NetworkBehaviourAction, + PollParameters, Swarm, SwarmBuilder}; + use tokio_io::{AsyncRead, AsyncWrite}; + use void::Void; + + #[derive(Clone)] + struct DummyBehaviour { + marker: PhantomData, + } + + trait TSubstream: AsyncRead + AsyncWrite {} + + impl NetworkBehaviour + for DummyBehaviour + where TSubstream: AsyncRead + AsyncWrite + { + type ProtocolsHandler = DummyProtocolsHandler; + type OutEvent = Void; + + fn new_handler(&mut self) -> Self::ProtocolsHandler { + DummyProtocolsHandler::default() + } + + fn inject_connected(&mut self, _: PeerId, _: ConnectedPoint) {} + + fn inject_disconnected(&mut self, _: &PeerId, _: ConnectedPoint) {} + + fn inject_node_event(&mut self, _: PeerId, + _: ::OutEvent) {} + + fn poll(&mut self, _:&mut PollParameters) -> + Async::InEvent, Self::OutEvent>> + { + Async::NotReady + } + + } + + fn get_random_id() -> PublicKey { + PublicKey::Rsa((0 .. 2048) + .map(|_| -> u8 { random() }) + .collect() + ) + } + + #[test] + fn test_build_swarm() { + let id = get_random_id(); + let transport = DummyTransport::new(); + let topology = MemoryTopology::empty(id); + let behaviour = DummyBehaviour{marker: PhantomData}; + let swarm = SwarmBuilder::new(transport, behaviour, + topology).incoming_limit(Some(4)).build(); + assert_eq!(swarm.raw_swarm.incoming_limit(), Some(4)); + } + + #[test] + fn test_build_swarm_with_max_listeners_none() { + let id = get_random_id(); + let transport = DummyTransport::new(); + let topology = MemoryTopology::empty(id); + let behaviour = DummyBehaviour{marker: PhantomData}; + let swarm = SwarmBuilder::new(transport, behaviour, topology) + .build(); + assert!(swarm.raw_swarm.incoming_limit().is_none()) + + } + + +}