diff --git a/core/CHANGELOG.md b/core/CHANGELOG.md index 5f85bcf2..e96957b5 100644 --- a/core/CHANGELOG.md +++ b/core/CHANGELOG.md @@ -16,6 +16,10 @@ two peer IDs are equal if and only if they use the same hash algorithm and have the same hash digest. [PR 1608](https://github.com/libp2p/rust-libp2p/pull/1608). +- Return dialer address instead of listener address as `remote_addr` in + `MemoryTransport` `Listener` `ListenerEvent::Upgrade` + [PR 1724](https://github.com/libp2p/rust-libp2p/pull/1724). + # 0.21.0 [2020-08-18] - Remove duplicates when performing address translation diff --git a/core/src/connection/listeners.rs b/core/src/connection/listeners.rs index bfd16a77..dab7b4fb 100644 --- a/core/src/connection/listeners.rs +++ b/core/src/connection/listeners.rs @@ -400,7 +400,7 @@ mod tests { match listeners.next().await.unwrap() { ListenersEvent::Incoming { local_addr, send_back_addr, .. } => { assert_eq!(local_addr, address); - assert_eq!(send_back_addr, address); + assert!(send_back_addr != address); }, _ => panic!() } diff --git a/core/src/transport/memory.rs b/core/src/transport/memory.rs index 8b7318c2..69be07d4 100644 --- a/core/src/transport/memory.rs +++ b/core/src/transport/memory.rs @@ -28,8 +28,57 @@ use rw_stream_sink::RwStreamSink; use std::{collections::hash_map::Entry, error, fmt, io, num::NonZeroU64, pin::Pin}; lazy_static! { - static ref HUB: Mutex>>>> = - Mutex::new(FnvHashMap::default()); + static ref HUB: Hub = Hub(Mutex::new(FnvHashMap::default())); +} + +struct Hub(Mutex>); + +/// A [`mpsc::Sender`] enabling a [`DialFuture`] to send a [`Channel`] and the +/// port of the dialer to a [`Listener`]. +type ChannelSender = mpsc::Sender<(Channel>, NonZeroU64)>; + +/// A [`mpsc::Receiver`] enabling a [`Listener`] to receive a [`Channel`] and +/// the port of the dialer from a [`DialFuture`]. +type ChannelReceiver = mpsc::Receiver<(Channel>, NonZeroU64)>; + +impl Hub { + /// Registers the given port on the hub. + /// + /// Randomizes port when given port is `0`. Returns [`None`] when given port + /// is already occupied. + fn register_port(&self, port: u64) -> Option<(ChannelReceiver, NonZeroU64)> { + let mut hub = self.0.lock(); + + let port = if let Some(port) = NonZeroU64::new(port) { + port + } else { + loop { + let port = match NonZeroU64::new(rand::random()) { + Some(p) => p, + None => continue, + }; + if !hub.contains_key(&port) { + break port; + } + } + }; + + let (tx, rx) = mpsc::channel(2); + match hub.entry(port) { + Entry::Occupied(_) => return None, + Entry::Vacant(e) => e.insert(tx) + }; + + Some((rx, port)) + } + + fn unregister_port(&self, port: &NonZeroU64) -> Option { + self.0.lock().remove(port) + } + + fn get(&self, port: &NonZeroU64) -> Option { + self.0.lock().get(port).cloned() + } } /// Transport that supports `/memory/N` multiaddresses. @@ -38,15 +87,49 @@ pub struct MemoryTransport; /// Connection to a `MemoryTransport` currently being opened. pub struct DialFuture { - sender: mpsc::Sender>>, + /// Ephemeral source port. + /// + /// These ports mimic TCP ephemeral source ports but are not actually used + /// by the memory transport due to the direct use of channels. They merely + /// ensure that every connection has a unique address for each dialer, which + /// is not at the same time a listen address (analogous to TCP). + dial_port: NonZeroU64, + sender: ChannelSender, channel_to_send: Option>>, channel_to_return: Option>>, } +impl DialFuture { + fn new(port: NonZeroU64) -> Option { + let sender = HUB.get(&port)?.clone(); + + let (_dial_port_channel, dial_port) = HUB.register_port(0) + .expect("there to be some random unoccupied port."); + + let (a_tx, a_rx) = mpsc::channel(4096); + let (b_tx, b_rx) = mpsc::channel(4096); + Some(DialFuture { + dial_port, + sender, + channel_to_send: Some(RwStreamSink::new(Chan { + incoming: a_rx, + outgoing: b_tx, + dial_port: None, + })), + channel_to_return: Some(RwStreamSink::new(Chan { + incoming: b_rx, + outgoing: a_tx, + dial_port: Some(dial_port), + })), + }) + } +} + impl Future for DialFuture { type Output = Result>, MemoryTransportError>; fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + match self.sender.poll_ready(cx) { Poll::Pending => return Poll::Pending, Poll::Ready(Ok(())) => {}, @@ -55,7 +138,8 @@ impl Future for DialFuture { let channel_to_send = self.channel_to_send.take() .expect("Future should not be polled again once complete"); - match self.sender.start_send(channel_to_send) { + let dial_port = self.dial_port; + match self.sender.start_send((channel_to_send, dial_port)) { Err(_) => return Poll::Ready(Err(MemoryTransportError::Unreachable)), Ok(()) => {} } @@ -79,28 +163,9 @@ impl Transport for MemoryTransport { return Err(TransportError::MultiaddrNotSupported(addr)); }; - let mut hub = (&*HUB).lock(); - - let port = if let Some(port) = NonZeroU64::new(port) { - port - } else { - loop { - let port = match NonZeroU64::new(rand::random()) { - Some(p) => p, - None => continue, - }; - if !hub.contains_key(&port) { - break port; - } - } - }; - - - let (tx, rx) = mpsc::channel(2); - match hub.entry(port) { - Entry::Occupied(_) => - return Err(TransportError::Other(MemoryTransportError::Unreachable)), - Entry::Vacant(e) => e.insert(tx) + let (rx, port) = match HUB.register_port(port) { + Some((rx, port)) => (rx, port), + None => return Err(TransportError::Other(MemoryTransportError::Unreachable)), }; let listener = Listener { @@ -124,19 +189,7 @@ impl Transport for MemoryTransport { return Err(TransportError::MultiaddrNotSupported(addr)); }; - let hub = HUB.lock(); - if let Some(sender) = hub.get(&port) { - let (a_tx, a_rx) = mpsc::channel(4096); - let (b_tx, b_rx) = mpsc::channel(4096); - Ok(DialFuture { - sender: sender.clone(), - channel_to_send: Some(RwStreamSink::new(Chan { incoming: a_rx, outgoing: b_tx })), - channel_to_return: Some(RwStreamSink::new(Chan { incoming: b_rx, outgoing: a_tx })), - - }) - } else { - Err(TransportError::Other(MemoryTransportError::Unreachable)) - } + DialFuture::new(port).ok_or(TransportError::Other(MemoryTransportError::Unreachable)) } } @@ -167,7 +220,7 @@ pub struct Listener { /// The address we are listening on. addr: Multiaddr, /// Receives incoming connections. - receiver: mpsc::Receiver>>, + receiver: ChannelReceiver, /// Generate `ListenerEvent::NewAddress` to inform about our listen address. tell_listen_addr: bool } @@ -181,7 +234,7 @@ impl Stream for Listener { return Poll::Ready(Some(Ok(ListenerEvent::NewAddress(self.addr.clone())))) } - let channel = match Stream::poll_next(Pin::new(&mut self.receiver), cx) { + let (channel, dial_port) = match Stream::poll_next(Pin::new(&mut self.receiver), cx) { Poll::Pending => return Poll::Pending, Poll::Ready(None) => panic!("Alive listeners always have a sender."), Poll::Ready(Some(v)) => v, @@ -190,7 +243,7 @@ impl Stream for Listener { let event = ListenerEvent::Upgrade { upgrade: future::ready(Ok(channel)), local_addr: self.addr.clone(), - remote_addr: Protocol::Memory(self.port.get()).into() + remote_addr: Protocol::Memory(dial_port.get()).into() }; Poll::Ready(Some(Ok(event))) @@ -199,7 +252,7 @@ impl Stream for Listener { impl Drop for Listener { fn drop(&mut self) { - let val_in = HUB.lock().remove(&self.port); + let val_in = HUB.unregister_port(&self.port); debug_assert!(val_in.is_some()); } } @@ -232,6 +285,14 @@ pub type Channel = RwStreamSink>; pub struct Chan> { incoming: mpsc::Receiver, outgoing: mpsc::Sender, + + // Needed in [`Drop`] implementation of [`Chan`] to unregister the dialing + // port with the global [`HUB`]. Is [`Some`] when [`Chan`] of dialer and + // [`None`] when [`Chan`] of listener. + // + // Note: Listening port is unregistered in [`Drop`] implementation of + // [`Listener`]. + dial_port: Option, } impl Unpin for Chan { @@ -276,6 +337,15 @@ impl> Into>> for Chan { } } +impl Drop for Chan { + fn drop(&mut self) { + if let Some(port) = self.dial_port { + let channel_sender = HUB.unregister_port(&port); + debug_assert!(channel_sender.is_some()); + } + } +} + #[cfg(test)] mod tests { use super::*; @@ -350,4 +420,90 @@ mod tests { futures::executor::block_on(futures::future::join(listener, dialer)); } + + #[test] + fn dialer_address_unequal_to_listener_address() { + let listener_addr: Multiaddr = Protocol::Memory( + rand::random::().saturating_add(1), + ).into(); + let listener_addr_cloned = listener_addr.clone(); + + let listener_transport = MemoryTransport::default(); + + let listener = async move { + let mut listener = listener_transport.listen_on(listener_addr.clone()) + .unwrap(); + while let Some(ev) = listener.next().await { + if let ListenerEvent::Upgrade { remote_addr, .. } = ev.unwrap() { + assert!( + remote_addr != listener_addr, + "Expect dialer address not to equal listener address." + ); + return; + } + } + }; + + let dialer = async move { + MemoryTransport::default().dial(listener_addr_cloned) + .unwrap() + .await + .unwrap(); + }; + + futures::executor::block_on(futures::future::join(listener, dialer)); + } + + #[test] + fn dialer_port_is_deregistered() { + let (terminate, should_terminate) = futures::channel::oneshot::channel(); + let (terminated, is_terminated) = futures::channel::oneshot::channel(); + + let listener_addr: Multiaddr = Protocol::Memory( + rand::random::().saturating_add(1), + ).into(); + let listener_addr_cloned = listener_addr.clone(); + + let listener_transport = MemoryTransport::default(); + + let listener = async move { + let mut listener = listener_transport.listen_on(listener_addr.clone()) + .unwrap(); + while let Some(ev) = listener.next().await { + if let ListenerEvent::Upgrade { remote_addr, .. } = ev.unwrap() { + let dialer_port = NonZeroU64::new( + parse_memory_addr(&remote_addr).unwrap(), + ).unwrap(); + + assert!( + HUB.get(&dialer_port).is_some(), + "Expect dialer port to stay registered while connection is in use.", + ); + + terminate.send(()).unwrap(); + is_terminated.await.unwrap(); + + assert!( + HUB.get(&dialer_port).is_none(), + "Expect dialer port to be deregistered once connection is dropped.", + ); + + return; + } + } + }; + + let dialer = async move { + let _chan = MemoryTransport::default().dial(listener_addr_cloned) + .unwrap() + .await + .unwrap(); + + should_terminate.await.unwrap(); + drop(_chan); + terminated.send(()).unwrap(); + }; + + futures::executor::block_on(futures::future::join(listener, dialer)); + } }