From 43e4d1f589e44ab2e5546571637cfafbbe7b5aeb Mon Sep 17 00:00:00 2001 From: Pierre Krieger Date: Mon, 18 Feb 2019 17:05:50 +0100 Subject: [PATCH] Rewrite the MemoryTransport to be similar to the TcpConfig (#951) * Rewrite the MemoryTransport to be similar to the TcpConfig * Add small test * Test and bug fixes --- core/Cargo.toml | 2 + core/src/nodes/listeners.rs | 12 +- core/src/nodes/raw_swarm/tests.rs | 14 +- core/src/transport/memory.rs | 257 +++++++++++++++++++----------- core/src/transport/mod.rs | 2 +- misc/multiaddr/src/protocol.rs | 24 ++- misc/multiaddr/tests/lib.rs | 2 +- 7 files changed, 197 insertions(+), 116 deletions(-) diff --git a/core/Cargo.toml b/core/Cargo.toml index 7a85a6e6..a0d350ad 100644 --- a/core/Cargo.toml +++ b/core/Cargo.toml @@ -13,6 +13,7 @@ categories = ["network-programming", "asynchronous"] bs58 = "0.2.0" bytes = "0.4" fnv = "1.0" +lazy_static = "1.2" log = "0.4" multiaddr = { package = "parity-multiaddr", version = "0.1.0", path = "../misc/multiaddr" } multihash = { package = "parity-multihash", version = "0.1.0", path = "../misc/multihash" } @@ -21,6 +22,7 @@ futures = { version = "0.1", features = ["use_std"] } parking_lot = "0.7" protobuf = "2.3" quick-error = "1.2" +rand = "0.6" rw-stream-sink = { version = "0.1.0", path = "../misc/rw-stream-sink" } smallvec = "0.6" tokio-executor = "0.1.4" diff --git a/core/src/nodes/listeners.rs b/core/src/nodes/listeners.rs index 19764448..a63f41f4 100644 --- a/core/src/nodes/listeners.rs +++ b/core/src/nodes/listeners.rs @@ -311,12 +311,12 @@ mod tests { #[test] fn incoming_event() { - let (tx, rx) = transport::connector(); + let mem_transport = transport::MemoryTransport::default(); - let mut listeners = ListenersStream::new(rx); - listeners.listen_on("/memory".parse().unwrap()).unwrap(); + let mut listeners = ListenersStream::new(mem_transport); + let actual_addr = listeners.listen_on("/memory/0".parse().unwrap()).unwrap(); - let dial = tx.dial("/memory".parse().unwrap()).unwrap(); + let dial = mem_transport.dial(actual_addr.clone()).unwrap(); let future = listeners .into_future() @@ -324,8 +324,8 @@ mod tests { .and_then(|(event, _)| { match event { Some(ListenersEvent::Incoming { listen_addr, upgrade, send_back_addr }) => { - assert_eq!(listen_addr, "/memory".parse().unwrap()); - assert_eq!(send_back_addr, "/memory".parse().unwrap()); + assert_eq!(listen_addr, actual_addr); + assert_eq!(send_back_addr, actual_addr); upgrade.map(|_| ()).map_err(|_| panic!()) }, _ => panic!() diff --git a/core/src/nodes/raw_swarm/tests.rs b/core/src/nodes/raw_swarm/tests.rs index acc40306..e0139c30 100644 --- a/core/src/nodes/raw_swarm/tests.rs +++ b/core/src/nodes/raw_swarm/tests.rs @@ -64,7 +64,7 @@ fn nat_traversal_transforms_the_observed_address_according_to_the_transport_used let mut raw_swarm = RawSwarm::<_, _, _, Handler, _>::new(transport, PeerId::random()); let addr1 = "/ip4/127.0.0.1/tcp/1234".parse::().expect("bad multiaddr"); // An unrelated outside address is returned as-is, no transform - let outside_addr1 = "/memory".parse::().expect("bad multiaddr"); + let outside_addr1 = "/memory/0".parse::().expect("bad multiaddr"); let addr2 = "/ip4/127.0.0.2/tcp/1234".parse::().expect("bad multiaddr"); let outside_addr2 = "/ip4/127.0.0.2/tcp/1234".parse::().expect("bad multiaddr"); @@ -128,7 +128,7 @@ fn num_incoming_negotiated() { transport.set_initial_listener_state(ListenerState::Ok(Async::Ready(Some((peer_id, muxer))))); let mut swarm = RawSwarm::<_, _, _, Handler, _>::new(transport, PeerId::random()); - swarm.listen_on("/memory".parse().unwrap()).unwrap(); + swarm.listen_on("/memory/0".parse().unwrap()).unwrap(); // no incoming yet assert_eq!(swarm.incoming_negotiated().count(), 0); @@ -203,7 +203,7 @@ fn querying_for_pending_peer() { let peer_id = PeerId::random(); let peer = swarm.peer(peer_id.clone()); assert_matches!(peer, Peer::NotConnected(PeerNotConnected{ .. })); - let addr = "/memory".parse().expect("bad multiaddr"); + let addr = "/memory/0".parse().expect("bad multiaddr"); let pending_peer = peer.into_not_connected().unwrap().connect(addr, Handler::default()); assert_matches!(pending_peer, PeerPendingConnect { .. }); } @@ -255,7 +255,7 @@ fn poll_with_closed_listener() { transport.set_initial_listener_state(ListenerState::Ok(Async::Ready(None))); let mut swarm = RawSwarm::<_, _, _, Handler, _>::new(transport, PeerId::random()); - swarm.listen_on("/memory".parse().unwrap()).unwrap(); + swarm.listen_on("/memory/0".parse().unwrap()).unwrap(); let mut rt = Runtime::new().unwrap(); let swarm = Arc::new(Mutex::new(swarm)); @@ -274,7 +274,7 @@ fn unknown_peer_that_is_unreachable_yields_unknown_peer_dial_error() { let mut transport = DummyTransport::new(); transport.make_dial_fail(); let mut swarm = RawSwarm::<_, _, _, Handler, _>::new(transport, PeerId::random()); - let addr = "/memory".parse::().expect("bad multiaddr"); + let addr = "/memory/0".parse::().expect("bad multiaddr"); let handler = Handler::default(); let dial_result = swarm.dial(addr, handler); assert!(dial_result.is_ok()); @@ -311,7 +311,7 @@ fn known_peer_that_is_unreachable_yields_dial_error() { let mut swarm1 = swarm1.lock(); let peer = swarm1.peer(peer_id.clone()); assert_matches!(peer, Peer::NotConnected(PeerNotConnected{ .. })); - let addr = "/memory".parse::().expect("bad multiaddr"); + let addr = "/memory/0".parse::().expect("bad multiaddr"); let pending_peer = peer.into_not_connected().unwrap().connect(addr, Handler::default()); assert_matches!(pending_peer, PeerPendingConnect { .. }); } @@ -466,7 +466,7 @@ fn limit_incoming_connections() { let mut swarm = RawSwarm::<_, _, _, Handler, _>::new_with_incoming_limit( transport, PeerId::random(), Some(limit)); assert_eq!(swarm.incoming_limit(), Some(limit)); - swarm.listen_on("/memory".parse().unwrap()).unwrap(); + swarm.listen_on("/memory/0".parse().unwrap()).unwrap(); assert_eq!(swarm.incoming_negotiated().count(), 0); let swarm = Arc::new(Mutex::new(swarm)); diff --git a/core/src/transport/memory.rs b/core/src/transport/memory.rs index d379a4fd..6fbde86f 100644 --- a/core/src/transport/memory.rs +++ b/core/src/transport/memory.rs @@ -20,62 +20,98 @@ use crate::{Transport, transport::TransportError}; use bytes::{Bytes, IntoBuf}; -use futures::{future::{self, FutureResult}, prelude::*, stream, sync::mpsc}; +use fnv::FnvHashMap; +use futures::{future::{self, FutureResult}, prelude::*, sync::mpsc, try_ready}; +use lazy_static::lazy_static; use multiaddr::{Protocol, Multiaddr}; use parking_lot::Mutex; use rw_stream_sink::RwStreamSink; -use std::{error, fmt, sync::Arc}; +use std::{collections::hash_map::Entry, error, fmt, io, num::NonZeroU64}; -/// Builds a new pair of `Transport`s. The dialer can reach the listener by dialing `/memory`. -#[inline] -pub fn connector() -> (Dialer, Listener) { - let (tx, rx) = mpsc::unbounded(); - (Dialer(tx), Listener(Arc::new(Mutex::new(rx)))) +lazy_static! { + static ref HUB: Mutex>>> = Mutex::new(FnvHashMap::default()); } -/// Same as `connector()`, but allows customizing the type used for transmitting packets between -/// the two endpoints. -#[inline] -pub fn connector_custom_type() -> (Dialer, Listener) { - let (tx, rx) = mpsc::unbounded(); - (Dialer(tx), Listener(Arc::new(Mutex::new(rx)))) -} +/// Transport that supports `/memory/N` multiaddresses. +#[derive(Debug, Copy, Clone, Default)] +pub struct MemoryTransport; -/// Dialing end of the memory transport. -pub struct Dialer(mpsc::UnboundedSender>); - -impl Clone for Dialer { - fn clone(&self) -> Self { - Dialer(self.0.clone()) - } -} - -impl Transport for Dialer { - type Output = Channel; +impl Transport for MemoryTransport { + type Output = Channel; type Error = MemoryTransportError; - type Listener = Box + Send>; - type ListenerUpgrade = FutureResult; - type Dial = Box + Send>; + type Listener = Listener; + type ListenerUpgrade = FutureResult; + type Dial = FutureResult; fn listen_on(self, addr: Multiaddr) -> Result<(Self::Listener, Multiaddr), TransportError> { - Err(TransportError::MultiaddrNotSupported(addr)) + let port = if let Ok(port) = parse_memory_addr(&addr) { + port + } else { + return Err(TransportError::MultiaddrNotSupported(addr)); + }; + + let mut hub = (&*HUB).lock(); + + let port = if let Some(port) = NonZeroU64::new(port) { + port + } else { + loop { + let port = match NonZeroU64::new(rand::random()) { + Some(p) => p, + None => continue, + }; + if !hub.contains_key(&port) { + break port; + } + } + }; + + let actual_addr = Protocol::Memory(port.get()).into(); + + let (tx, rx) = mpsc::unbounded(); + match hub.entry(port) { + Entry::Occupied(_) => return Err(TransportError::Other(MemoryTransportError::Unreachable)), + Entry::Vacant(e) => e.insert(tx), + }; + + let listener = Listener { + port, + receiver: rx, + }; + + Ok((listener, actual_addr)) } fn dial(self, addr: Multiaddr) -> Result> { - if !is_memory_addr(&addr) { - return Err(TransportError::MultiaddrNotSupported(addr)) - } - let (a_tx, a_rx) = mpsc::unbounded(); - let (b_tx, b_rx) = mpsc::unbounded(); - let a = Chan { incoming: a_rx, outgoing: b_tx }; - let b = Chan { incoming: b_rx, outgoing: a_tx }; - let future = self.0.send(b) - .map(move |_| a.into()) - .map_err(|_| MemoryTransportError::RemoteClosed); - Ok(Box::new(future)) + let port = if let Ok(port) = parse_memory_addr(&addr) { + if let Some(port) = NonZeroU64::new(port) { + port + } else { + return Err(TransportError::Other(MemoryTransportError::Unreachable)); + } + } else { + return Err(TransportError::MultiaddrNotSupported(addr)); + }; + + let hub = HUB.lock(); + let chan = if let Some(tx) = hub.get(&port) { + let (a_tx, a_rx) = mpsc::unbounded(); + let (b_tx, b_rx) = mpsc::unbounded(); + let a = RwStreamSink::new(Chan { incoming: a_rx, outgoing: b_tx }); + let b = RwStreamSink::new(Chan { incoming: b_rx, outgoing: a_tx }); + if tx.unbounded_send(b).is_err() { + return Err(TransportError::Other(MemoryTransportError::Unreachable)); + } + a + } else { + return Err(TransportError::Other(MemoryTransportError::Unreachable)); + }; + + Ok(future::ok(chan)) } fn nat_traversal(&self, server: &Multiaddr, observed: &Multiaddr) -> Option { + // TODO: NAT traversal for `/memory` addresses? how does that make sense? if server == observed { Some(server.clone()) } else { @@ -87,76 +123,69 @@ impl Transport for Dialer { /// Error that can be produced from the `MemoryTransport`. #[derive(Debug, Copy, Clone)] pub enum MemoryTransportError { - /// The other side of the transport has been closed earlier. - RemoteClosed, + /// There's no listener on the given port. + Unreachable, + /// Tries to listen on a port that is already in use. + AlreadyInUse, } impl fmt::Display for MemoryTransportError { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { match *self { - MemoryTransportError::RemoteClosed => - write!(f, "The other side of the memory transport has been closed."), + MemoryTransportError::Unreachable => write!(f, "No listener on the given port."), + MemoryTransportError::AlreadyInUse => write!(f, "Port already occupied."), } } } impl error::Error for MemoryTransportError {} -/// Receiving end of the memory transport. -pub struct Listener(Arc>>>); - -impl Clone for Listener { - fn clone(&self) -> Self { - Listener(self.0.clone()) - } +/// Listener for memory connections. +pub struct Listener { + /// Port we're listening on. + port: NonZeroU64, + /// Receives incoming connections. + receiver: mpsc::UnboundedReceiver>, } -impl Transport for Listener { - type Output = Channel; +impl Stream for Listener { + type Item = (FutureResult, MemoryTransportError>, Multiaddr); type Error = MemoryTransportError; - type Listener = Box + Send>; - type ListenerUpgrade = FutureResult; - type Dial = Box + Send>; - fn listen_on(self, addr: Multiaddr) -> Result<(Self::Listener, Multiaddr), TransportError> { - if !is_memory_addr(&addr) { - return Err(TransportError::MultiaddrNotSupported(addr)); - } - let addr2 = addr.clone(); - let receiver = self.0.clone(); - let stream = stream::poll_fn(move || receiver.lock().poll()) - .map(move |channel| { - (future::ok(channel.into()), addr.clone()) - }) - .map_err(|()| unreachable!()); - Ok((Box::new(stream), addr2)) - } - - #[inline] - fn dial(self, addr: Multiaddr) -> Result> { - Err(TransportError::MultiaddrNotSupported(addr)) - } - - #[inline] - fn nat_traversal(&self, server: &Multiaddr, observed: &Multiaddr) -> Option { - if server == observed { - Some(server.clone()) - } else { - None - } + fn poll(&mut self) -> Poll, Self::Error> { + let channel = try_ready!(Ok(self.receiver.poll() + .expect("An unbounded receiver never panics; QED"))); + let channel = match channel { + Some(c) => c, + None => return Ok(Async::Ready(None)), + }; + let dialed_addr = Protocol::Memory(self.port.get()).into(); + Ok(Async::Ready(Some((future::ok(channel), dialed_addr)))) } } -/// Returns `true` if and only if the address is `/memory`. -fn is_memory_addr(a: &Multiaddr) -> bool { +impl Drop for Listener { + fn drop(&mut self) { + let val_in = HUB.lock().remove(&self.port); + debug_assert!(val_in.is_some()); + } +} + +/// If the address is `/memory/n`, returns the value of `n`. +fn parse_memory_addr(a: &Multiaddr) -> Result { let mut iter = a.iter(); - if iter.next() != Some(Protocol::Memory) { - return false; - } + + let port = if let Some(Protocol::Memory(port)) = iter.next() { + port + } else { + return Err(()); + }; + if iter.next().is_some() { - return false; + return Err(()); } - true + + Ok(port) } /// A channel represents an established, in-memory, logical connection between two endpoints. @@ -174,31 +203,31 @@ pub struct Chan { impl Stream for Chan { type Item = T; - type Error = MemoryTransportError; + type Error = io::Error; #[inline] fn poll(&mut self) -> Poll, Self::Error> { - self.incoming.poll().map_err(|()| MemoryTransportError::RemoteClosed) + self.incoming.poll().map_err(|()| io::ErrorKind::BrokenPipe.into()) } } impl Sink for Chan { type SinkItem = T; - type SinkError = MemoryTransportError; + type SinkError = io::Error; #[inline] fn start_send(&mut self, item: Self::SinkItem) -> StartSend { - self.outgoing.start_send(item).map_err(|_| MemoryTransportError::RemoteClosed) + self.outgoing.start_send(item).map_err(|_| io::ErrorKind::BrokenPipe.into()) } #[inline] fn poll_complete(&mut self) -> Poll<(), Self::SinkError> { - self.outgoing.poll_complete().map_err(|_| MemoryTransportError::RemoteClosed) + self.outgoing.poll_complete().map_err(|_| io::ErrorKind::BrokenPipe.into()) } #[inline] fn close(&mut self) -> Poll<(), Self::SinkError> { - self.outgoing.close().map_err(|_| MemoryTransportError::RemoteClosed) + self.outgoing.close().map_err(|_| io::ErrorKind::BrokenPipe.into()) } } @@ -208,3 +237,41 @@ impl Into>> for Chan { RwStreamSink::new(self) } } + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn parse_memory_addr_works() { + assert_eq!(parse_memory_addr(&"/memory/5".parse().unwrap()), Ok(5)); + assert_eq!(parse_memory_addr(&"/tcp/150".parse().unwrap()), Err(())); + assert_eq!(parse_memory_addr(&"/memory/0".parse().unwrap()), Ok(0)); + assert_eq!(parse_memory_addr(&"/memory/5/tcp/150".parse().unwrap()), Err(())); + assert_eq!(parse_memory_addr(&"/tcp/150/memory/5".parse().unwrap()), Err(())); + assert_eq!(parse_memory_addr(&"/memory/1234567890".parse().unwrap()), Ok(1_234_567_890)); + } + + #[test] + fn listening_twice() { + let transport = MemoryTransport::default(); + assert!(transport.listen_on("/memory/5".parse().unwrap()).is_ok()); + assert!(transport.listen_on("/memory/5".parse().unwrap()).is_ok()); + let _listener = transport.listen_on("/memory/5".parse().unwrap()).unwrap(); + assert!(transport.listen_on("/memory/5".parse().unwrap()).is_err()); + assert!(transport.listen_on("/memory/5".parse().unwrap()).is_err()); + drop(_listener); + assert!(transport.listen_on("/memory/5".parse().unwrap()).is_ok()); + assert!(transport.listen_on("/memory/5".parse().unwrap()).is_ok()); + } + + #[test] + fn port_not_in_use() { + let transport = MemoryTransport::default(); + assert!(transport.dial("/memory/5".parse().unwrap()).is_err()); + let _listener = transport.listen_on("/memory/5".parse().unwrap()).unwrap(); + assert!(transport.dial("/memory/5".parse().unwrap()).is_ok()); + } + + // TODO: test that is actually works +} diff --git a/core/src/transport/mod.rs b/core/src/transport/mod.rs index a8bdacbd..5aaed277 100644 --- a/core/src/transport/mod.rs +++ b/core/src/transport/mod.rs @@ -44,7 +44,7 @@ pub mod timeout; pub mod upgrade; pub use self::choice::OrTransport; -pub use self::memory::connector; +pub use self::memory::MemoryTransport; pub use self::upgrade::Upgrade; /// A transport is an object that can be used to produce connections by listening or dialing a diff --git a/misc/multiaddr/src/protocol.rs b/misc/multiaddr/src/protocol.rs index 8755a940..fb1f3500 100644 --- a/misc/multiaddr/src/protocol.rs +++ b/misc/multiaddr/src/protocol.rs @@ -52,7 +52,8 @@ pub enum Protocol<'a> { P2pWebRtcDirect, P2pWebRtcStar, P2pWebSocketStar, - Memory, + /// Contains the "port" to contact. Similar to TCP or UDP, 0 means "assign me a port". + Memory(u64), Onion(Cow<'a, [u8; 10]>, u16), P2p(Multihash), P2pCircuit, @@ -139,7 +140,10 @@ impl<'a> Protocol<'a> { "p2p-webrtc-star" => Ok(Protocol::P2pWebRtcStar), "p2p-webrtc-direct" => Ok(Protocol::P2pWebRtcDirect), "p2p-circuit" => Ok(Protocol::P2pCircuit), - "memory" => Ok(Protocol::Memory), + "memory" => { + let s = iter.next().ok_or(Error::InvalidProtocolString)?; + Ok(Protocol::Memory(s.parse()?)) + } _ => Err(Error::UnknownProtocolString) } } @@ -200,7 +204,12 @@ impl<'a> Protocol<'a> { P2P_WEBRTC_DIRECT => Ok((Protocol::P2pWebRtcDirect, input)), P2P_WEBRTC_STAR => Ok((Protocol::P2pWebRtcStar, input)), P2P_WEBSOCKET_STAR => Ok((Protocol::P2pWebSocketStar, input)), - MEMORY => Ok((Protocol::Memory, input)), + MEMORY => { + let (data, rest) = split_at(8, input)?; + let mut rdr = Cursor::new(data); + let num = rdr.read_u64::()?; + Ok((Protocol::Memory(num), rest)) + } ONION => { let (data, rest) = split_at(12, input)?; let port = BigEndian::read_u16(&data[10 ..]); @@ -315,7 +324,10 @@ impl<'a> Protocol<'a> { Protocol::P2pWebRtcStar => w.write_all(encode::u32(P2P_WEBRTC_STAR, &mut buf))?, Protocol::P2pWebRtcDirect => w.write_all(encode::u32(P2P_WEBRTC_DIRECT, &mut buf))?, Protocol::P2pCircuit => w.write_all(encode::u32(P2P_CIRCUIT, &mut buf))?, - Protocol::Memory => w.write_all(encode::u32(MEMORY, &mut buf))? + Protocol::Memory(port) => { + w.write_all(encode::u32(MEMORY, &mut buf))?; + w.write_u64::(*port)? + } } Ok(()) } @@ -334,7 +346,7 @@ impl<'a> Protocol<'a> { P2pWebRtcDirect => P2pWebRtcDirect, P2pWebRtcStar => P2pWebRtcStar, P2pWebSocketStar => P2pWebSocketStar, - Memory => Memory, + Memory(a) => Memory(a), Onion(addr, port) => Onion(Cow::Owned(addr.into_owned()), port), P2p(a) => P2p(a), P2pCircuit => P2pCircuit, @@ -365,7 +377,7 @@ impl<'a> fmt::Display for Protocol<'a> { P2pWebRtcDirect => f.write_str("/p2p-webrtc-direct"), P2pWebRtcStar => f.write_str("/p2p-webrtc-star"), P2pWebSocketStar => f.write_str("/p2p-websocket-star"), - Memory => f.write_str("/memory"), + Memory(port) => write!(f, "/memory/{}", port), Onion(addr, port) => { let s = BASE32.encode(addr.as_ref()); write!(f, "/onion/{}:{}", s.to_lowercase(), port) diff --git a/misc/multiaddr/tests/lib.rs b/misc/multiaddr/tests/lib.rs index e53b2eae..d42c8521 100644 --- a/misc/multiaddr/tests/lib.rs +++ b/misc/multiaddr/tests/lib.rs @@ -62,7 +62,7 @@ impl Arbitrary for Proto { 7 => Proto(P2pWebRtcDirect), 8 => Proto(P2pWebRtcStar), 9 => Proto(P2pWebSocketStar), - 10 => Proto(Memory), + 10 => Proto(Memory(g.gen())), // TODO: impl Arbitrary for Multihash: 11 => Proto(P2p(multihash("QmcgpsyWgH8Y8ajJz1Cu72KnS5uo2Aa2LpzU7kinSupNKC"))), 12 => Proto(P2pCircuit),