core/transport/memory: Return dialer address in Upgrade event (#1724)

Previously a `Listener` would return its own address as the
`remote_addr` in the `ListenerEvent::Upgrade` event.

With this commit a `Listener` returns the dialer address as the
`remote_addr` in the `ListenerEvent::Upgrade` event. To do so a
`DialFuture` registers a port with the global `HUB` at construction
which is later on unregistered in the `Drop` implementation of the
dialer's `Chan`. The sending side of the `mpsc::channel` registered in
the `HUB` is dropped at `DialFuture` construction, thus one can not dial
the dialer port. This mimics the TCP transport behaviour preventing both
dialing and listening on the same TCP port.
This commit is contained in:
Max Inden
2020-09-08 12:07:15 +02:00
committed by GitHub
parent 6599ff13e1
commit 244c5aa87a
3 changed files with 204 additions and 44 deletions

View File

@ -16,6 +16,10 @@
two peer IDs are equal if and only if they use the same hash algorithm and two peer IDs are equal if and only if they use the same hash algorithm and
have the same hash digest. [PR 1608](https://github.com/libp2p/rust-libp2p/pull/1608). have the same hash digest. [PR 1608](https://github.com/libp2p/rust-libp2p/pull/1608).
- Return dialer address instead of listener address as `remote_addr` in
`MemoryTransport` `Listener` `ListenerEvent::Upgrade`
[PR 1724](https://github.com/libp2p/rust-libp2p/pull/1724).
# 0.21.0 [2020-08-18] # 0.21.0 [2020-08-18]
- Remove duplicates when performing address translation - Remove duplicates when performing address translation

View File

@ -400,7 +400,7 @@ mod tests {
match listeners.next().await.unwrap() { match listeners.next().await.unwrap() {
ListenersEvent::Incoming { local_addr, send_back_addr, .. } => { ListenersEvent::Incoming { local_addr, send_back_addr, .. } => {
assert_eq!(local_addr, address); assert_eq!(local_addr, address);
assert_eq!(send_back_addr, address); assert!(send_back_addr != address);
}, },
_ => panic!() _ => panic!()
} }

View File

@ -28,8 +28,57 @@ use rw_stream_sink::RwStreamSink;
use std::{collections::hash_map::Entry, error, fmt, io, num::NonZeroU64, pin::Pin}; use std::{collections::hash_map::Entry, error, fmt, io, num::NonZeroU64, pin::Pin};
lazy_static! { lazy_static! {
static ref HUB: Mutex<FnvHashMap<NonZeroU64, mpsc::Sender<Channel<Vec<u8>>>>> = static ref HUB: Hub = Hub(Mutex::new(FnvHashMap::default()));
Mutex::new(FnvHashMap::default()); }
struct Hub(Mutex<FnvHashMap<NonZeroU64, ChannelSender>>);
/// A [`mpsc::Sender`] enabling a [`DialFuture`] to send a [`Channel`] and the
/// port of the dialer to a [`Listener`].
type ChannelSender = mpsc::Sender<(Channel<Vec<u8>>, NonZeroU64)>;
/// A [`mpsc::Receiver`] enabling a [`Listener`] to receive a [`Channel`] and
/// the port of the dialer from a [`DialFuture`].
type ChannelReceiver = mpsc::Receiver<(Channel<Vec<u8>>, NonZeroU64)>;
impl Hub {
/// Registers the given port on the hub.
///
/// Randomizes port when given port is `0`. Returns [`None`] when given port
/// is already occupied.
fn register_port(&self, port: u64) -> Option<(ChannelReceiver, NonZeroU64)> {
let mut hub = self.0.lock();
let port = if let Some(port) = NonZeroU64::new(port) {
port
} else {
loop {
let port = match NonZeroU64::new(rand::random()) {
Some(p) => p,
None => continue,
};
if !hub.contains_key(&port) {
break port;
}
}
};
let (tx, rx) = mpsc::channel(2);
match hub.entry(port) {
Entry::Occupied(_) => return None,
Entry::Vacant(e) => e.insert(tx)
};
Some((rx, port))
}
fn unregister_port(&self, port: &NonZeroU64) -> Option<ChannelSender> {
self.0.lock().remove(port)
}
fn get(&self, port: &NonZeroU64) -> Option<ChannelSender> {
self.0.lock().get(port).cloned()
}
} }
/// Transport that supports `/memory/N` multiaddresses. /// Transport that supports `/memory/N` multiaddresses.
@ -38,15 +87,49 @@ pub struct MemoryTransport;
/// Connection to a `MemoryTransport` currently being opened. /// Connection to a `MemoryTransport` currently being opened.
pub struct DialFuture { pub struct DialFuture {
sender: mpsc::Sender<Channel<Vec<u8>>>, /// Ephemeral source port.
///
/// These ports mimic TCP ephemeral source ports but are not actually used
/// by the memory transport due to the direct use of channels. They merely
/// ensure that every connection has a unique address for each dialer, which
/// is not at the same time a listen address (analogous to TCP).
dial_port: NonZeroU64,
sender: ChannelSender,
channel_to_send: Option<Channel<Vec<u8>>>, channel_to_send: Option<Channel<Vec<u8>>>,
channel_to_return: Option<Channel<Vec<u8>>>, channel_to_return: Option<Channel<Vec<u8>>>,
} }
impl DialFuture {
fn new(port: NonZeroU64) -> Option<Self> {
let sender = HUB.get(&port)?.clone();
let (_dial_port_channel, dial_port) = HUB.register_port(0)
.expect("there to be some random unoccupied port.");
let (a_tx, a_rx) = mpsc::channel(4096);
let (b_tx, b_rx) = mpsc::channel(4096);
Some(DialFuture {
dial_port,
sender,
channel_to_send: Some(RwStreamSink::new(Chan {
incoming: a_rx,
outgoing: b_tx,
dial_port: None,
})),
channel_to_return: Some(RwStreamSink::new(Chan {
incoming: b_rx,
outgoing: a_tx,
dial_port: Some(dial_port),
})),
})
}
}
impl Future for DialFuture { impl Future for DialFuture {
type Output = Result<Channel<Vec<u8>>, MemoryTransportError>; type Output = Result<Channel<Vec<u8>>, MemoryTransportError>;
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
match self.sender.poll_ready(cx) { match self.sender.poll_ready(cx) {
Poll::Pending => return Poll::Pending, Poll::Pending => return Poll::Pending,
Poll::Ready(Ok(())) => {}, Poll::Ready(Ok(())) => {},
@ -55,7 +138,8 @@ impl Future for DialFuture {
let channel_to_send = self.channel_to_send.take() let channel_to_send = self.channel_to_send.take()
.expect("Future should not be polled again once complete"); .expect("Future should not be polled again once complete");
match self.sender.start_send(channel_to_send) { let dial_port = self.dial_port;
match self.sender.start_send((channel_to_send, dial_port)) {
Err(_) => return Poll::Ready(Err(MemoryTransportError::Unreachable)), Err(_) => return Poll::Ready(Err(MemoryTransportError::Unreachable)),
Ok(()) => {} Ok(()) => {}
} }
@ -79,28 +163,9 @@ impl Transport for MemoryTransport {
return Err(TransportError::MultiaddrNotSupported(addr)); return Err(TransportError::MultiaddrNotSupported(addr));
}; };
let mut hub = (&*HUB).lock(); let (rx, port) = match HUB.register_port(port) {
Some((rx, port)) => (rx, port),
let port = if let Some(port) = NonZeroU64::new(port) { None => return Err(TransportError::Other(MemoryTransportError::Unreachable)),
port
} else {
loop {
let port = match NonZeroU64::new(rand::random()) {
Some(p) => p,
None => continue,
};
if !hub.contains_key(&port) {
break port;
}
}
};
let (tx, rx) = mpsc::channel(2);
match hub.entry(port) {
Entry::Occupied(_) =>
return Err(TransportError::Other(MemoryTransportError::Unreachable)),
Entry::Vacant(e) => e.insert(tx)
}; };
let listener = Listener { let listener = Listener {
@ -124,19 +189,7 @@ impl Transport for MemoryTransport {
return Err(TransportError::MultiaddrNotSupported(addr)); return Err(TransportError::MultiaddrNotSupported(addr));
}; };
let hub = HUB.lock(); DialFuture::new(port).ok_or(TransportError::Other(MemoryTransportError::Unreachable))
if let Some(sender) = hub.get(&port) {
let (a_tx, a_rx) = mpsc::channel(4096);
let (b_tx, b_rx) = mpsc::channel(4096);
Ok(DialFuture {
sender: sender.clone(),
channel_to_send: Some(RwStreamSink::new(Chan { incoming: a_rx, outgoing: b_tx })),
channel_to_return: Some(RwStreamSink::new(Chan { incoming: b_rx, outgoing: a_tx })),
})
} else {
Err(TransportError::Other(MemoryTransportError::Unreachable))
}
} }
} }
@ -167,7 +220,7 @@ pub struct Listener {
/// The address we are listening on. /// The address we are listening on.
addr: Multiaddr, addr: Multiaddr,
/// Receives incoming connections. /// Receives incoming connections.
receiver: mpsc::Receiver<Channel<Vec<u8>>>, receiver: ChannelReceiver,
/// Generate `ListenerEvent::NewAddress` to inform about our listen address. /// Generate `ListenerEvent::NewAddress` to inform about our listen address.
tell_listen_addr: bool tell_listen_addr: bool
} }
@ -181,7 +234,7 @@ impl Stream for Listener {
return Poll::Ready(Some(Ok(ListenerEvent::NewAddress(self.addr.clone())))) return Poll::Ready(Some(Ok(ListenerEvent::NewAddress(self.addr.clone()))))
} }
let channel = match Stream::poll_next(Pin::new(&mut self.receiver), cx) { let (channel, dial_port) = match Stream::poll_next(Pin::new(&mut self.receiver), cx) {
Poll::Pending => return Poll::Pending, Poll::Pending => return Poll::Pending,
Poll::Ready(None) => panic!("Alive listeners always have a sender."), Poll::Ready(None) => panic!("Alive listeners always have a sender."),
Poll::Ready(Some(v)) => v, Poll::Ready(Some(v)) => v,
@ -190,7 +243,7 @@ impl Stream for Listener {
let event = ListenerEvent::Upgrade { let event = ListenerEvent::Upgrade {
upgrade: future::ready(Ok(channel)), upgrade: future::ready(Ok(channel)),
local_addr: self.addr.clone(), local_addr: self.addr.clone(),
remote_addr: Protocol::Memory(self.port.get()).into() remote_addr: Protocol::Memory(dial_port.get()).into()
}; };
Poll::Ready(Some(Ok(event))) Poll::Ready(Some(Ok(event)))
@ -199,7 +252,7 @@ impl Stream for Listener {
impl Drop for Listener { impl Drop for Listener {
fn drop(&mut self) { fn drop(&mut self) {
let val_in = HUB.lock().remove(&self.port); let val_in = HUB.unregister_port(&self.port);
debug_assert!(val_in.is_some()); debug_assert!(val_in.is_some());
} }
} }
@ -232,6 +285,14 @@ pub type Channel<T> = RwStreamSink<Chan<T>>;
pub struct Chan<T = Vec<u8>> { pub struct Chan<T = Vec<u8>> {
incoming: mpsc::Receiver<T>, incoming: mpsc::Receiver<T>,
outgoing: mpsc::Sender<T>, outgoing: mpsc::Sender<T>,
// Needed in [`Drop`] implementation of [`Chan`] to unregister the dialing
// port with the global [`HUB`]. Is [`Some`] when [`Chan`] of dialer and
// [`None`] when [`Chan`] of listener.
//
// Note: Listening port is unregistered in [`Drop`] implementation of
// [`Listener`].
dial_port: Option<NonZeroU64>,
} }
impl<T> Unpin for Chan<T> { impl<T> Unpin for Chan<T> {
@ -276,6 +337,15 @@ impl<T: AsRef<[u8]>> Into<RwStreamSink<Chan<T>>> for Chan<T> {
} }
} }
impl<T> Drop for Chan<T> {
fn drop(&mut self) {
if let Some(port) = self.dial_port {
let channel_sender = HUB.unregister_port(&port);
debug_assert!(channel_sender.is_some());
}
}
}
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::*; use super::*;
@ -350,4 +420,90 @@ mod tests {
futures::executor::block_on(futures::future::join(listener, dialer)); futures::executor::block_on(futures::future::join(listener, dialer));
} }
#[test]
fn dialer_address_unequal_to_listener_address() {
let listener_addr: Multiaddr = Protocol::Memory(
rand::random::<u64>().saturating_add(1),
).into();
let listener_addr_cloned = listener_addr.clone();
let listener_transport = MemoryTransport::default();
let listener = async move {
let mut listener = listener_transport.listen_on(listener_addr.clone())
.unwrap();
while let Some(ev) = listener.next().await {
if let ListenerEvent::Upgrade { remote_addr, .. } = ev.unwrap() {
assert!(
remote_addr != listener_addr,
"Expect dialer address not to equal listener address."
);
return;
}
}
};
let dialer = async move {
MemoryTransport::default().dial(listener_addr_cloned)
.unwrap()
.await
.unwrap();
};
futures::executor::block_on(futures::future::join(listener, dialer));
}
#[test]
fn dialer_port_is_deregistered() {
let (terminate, should_terminate) = futures::channel::oneshot::channel();
let (terminated, is_terminated) = futures::channel::oneshot::channel();
let listener_addr: Multiaddr = Protocol::Memory(
rand::random::<u64>().saturating_add(1),
).into();
let listener_addr_cloned = listener_addr.clone();
let listener_transport = MemoryTransport::default();
let listener = async move {
let mut listener = listener_transport.listen_on(listener_addr.clone())
.unwrap();
while let Some(ev) = listener.next().await {
if let ListenerEvent::Upgrade { remote_addr, .. } = ev.unwrap() {
let dialer_port = NonZeroU64::new(
parse_memory_addr(&remote_addr).unwrap(),
).unwrap();
assert!(
HUB.get(&dialer_port).is_some(),
"Expect dialer port to stay registered while connection is in use.",
);
terminate.send(()).unwrap();
is_terminated.await.unwrap();
assert!(
HUB.get(&dialer_port).is_none(),
"Expect dialer port to be deregistered once connection is dropped.",
);
return;
}
}
};
let dialer = async move {
let _chan = MemoryTransport::default().dial(listener_addr_cloned)
.unwrap()
.await
.unwrap();
should_terminate.await.unwrap();
drop(_chan);
terminated.send(()).unwrap();
};
futures::executor::block_on(futures::future::join(listener, dialer));
}
} }