core/transport/memory: Return dialer address in Upgrade event (#1724)

Previously a `Listener` would return its own address as the
`remote_addr` in the `ListenerEvent::Upgrade` event.

With this commit a `Listener` returns the dialer address as the
`remote_addr` in the `ListenerEvent::Upgrade` event. To do so a
`DialFuture` registers a port with the global `HUB` at construction
which is later on unregistered in the `Drop` implementation of the
dialer's `Chan`. The sending side of the `mpsc::channel` registered in
the `HUB` is dropped at `DialFuture` construction, thus one can not dial
the dialer port. This mimics the TCP transport behaviour preventing both
dialing and listening on the same TCP port.
This commit is contained in:
Max Inden
2020-09-08 12:07:15 +02:00
committed by GitHub
parent 6599ff13e1
commit 244c5aa87a
3 changed files with 204 additions and 44 deletions

View File

@ -16,6 +16,10 @@
two peer IDs are equal if and only if they use the same hash algorithm and
have the same hash digest. [PR 1608](https://github.com/libp2p/rust-libp2p/pull/1608).
- Return dialer address instead of listener address as `remote_addr` in
`MemoryTransport` `Listener` `ListenerEvent::Upgrade`
[PR 1724](https://github.com/libp2p/rust-libp2p/pull/1724).
# 0.21.0 [2020-08-18]
- Remove duplicates when performing address translation

View File

@ -400,7 +400,7 @@ mod tests {
match listeners.next().await.unwrap() {
ListenersEvent::Incoming { local_addr, send_back_addr, .. } => {
assert_eq!(local_addr, address);
assert_eq!(send_back_addr, address);
assert!(send_back_addr != address);
},
_ => panic!()
}

View File

@ -28,8 +28,57 @@ use rw_stream_sink::RwStreamSink;
use std::{collections::hash_map::Entry, error, fmt, io, num::NonZeroU64, pin::Pin};
lazy_static! {
static ref HUB: Mutex<FnvHashMap<NonZeroU64, mpsc::Sender<Channel<Vec<u8>>>>> =
Mutex::new(FnvHashMap::default());
static ref HUB: Hub = Hub(Mutex::new(FnvHashMap::default()));
}
struct Hub(Mutex<FnvHashMap<NonZeroU64, ChannelSender>>);
/// A [`mpsc::Sender`] enabling a [`DialFuture`] to send a [`Channel`] and the
/// port of the dialer to a [`Listener`].
type ChannelSender = mpsc::Sender<(Channel<Vec<u8>>, NonZeroU64)>;
/// A [`mpsc::Receiver`] enabling a [`Listener`] to receive a [`Channel`] and
/// the port of the dialer from a [`DialFuture`].
type ChannelReceiver = mpsc::Receiver<(Channel<Vec<u8>>, NonZeroU64)>;
impl Hub {
/// Registers the given port on the hub.
///
/// Randomizes port when given port is `0`. Returns [`None`] when given port
/// is already occupied.
fn register_port(&self, port: u64) -> Option<(ChannelReceiver, NonZeroU64)> {
let mut hub = self.0.lock();
let port = if let Some(port) = NonZeroU64::new(port) {
port
} else {
loop {
let port = match NonZeroU64::new(rand::random()) {
Some(p) => p,
None => continue,
};
if !hub.contains_key(&port) {
break port;
}
}
};
let (tx, rx) = mpsc::channel(2);
match hub.entry(port) {
Entry::Occupied(_) => return None,
Entry::Vacant(e) => e.insert(tx)
};
Some((rx, port))
}
fn unregister_port(&self, port: &NonZeroU64) -> Option<ChannelSender> {
self.0.lock().remove(port)
}
fn get(&self, port: &NonZeroU64) -> Option<ChannelSender> {
self.0.lock().get(port).cloned()
}
}
/// Transport that supports `/memory/N` multiaddresses.
@ -38,15 +87,49 @@ pub struct MemoryTransport;
/// Connection to a `MemoryTransport` currently being opened.
pub struct DialFuture {
sender: mpsc::Sender<Channel<Vec<u8>>>,
/// Ephemeral source port.
///
/// These ports mimic TCP ephemeral source ports but are not actually used
/// by the memory transport due to the direct use of channels. They merely
/// ensure that every connection has a unique address for each dialer, which
/// is not at the same time a listen address (analogous to TCP).
dial_port: NonZeroU64,
sender: ChannelSender,
channel_to_send: Option<Channel<Vec<u8>>>,
channel_to_return: Option<Channel<Vec<u8>>>,
}
impl DialFuture {
fn new(port: NonZeroU64) -> Option<Self> {
let sender = HUB.get(&port)?.clone();
let (_dial_port_channel, dial_port) = HUB.register_port(0)
.expect("there to be some random unoccupied port.");
let (a_tx, a_rx) = mpsc::channel(4096);
let (b_tx, b_rx) = mpsc::channel(4096);
Some(DialFuture {
dial_port,
sender,
channel_to_send: Some(RwStreamSink::new(Chan {
incoming: a_rx,
outgoing: b_tx,
dial_port: None,
})),
channel_to_return: Some(RwStreamSink::new(Chan {
incoming: b_rx,
outgoing: a_tx,
dial_port: Some(dial_port),
})),
})
}
}
impl Future for DialFuture {
type Output = Result<Channel<Vec<u8>>, MemoryTransportError>;
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
match self.sender.poll_ready(cx) {
Poll::Pending => return Poll::Pending,
Poll::Ready(Ok(())) => {},
@ -55,7 +138,8 @@ impl Future for DialFuture {
let channel_to_send = self.channel_to_send.take()
.expect("Future should not be polled again once complete");
match self.sender.start_send(channel_to_send) {
let dial_port = self.dial_port;
match self.sender.start_send((channel_to_send, dial_port)) {
Err(_) => return Poll::Ready(Err(MemoryTransportError::Unreachable)),
Ok(()) => {}
}
@ -79,28 +163,9 @@ impl Transport for MemoryTransport {
return Err(TransportError::MultiaddrNotSupported(addr));
};
let mut hub = (&*HUB).lock();
let port = if let Some(port) = NonZeroU64::new(port) {
port
} else {
loop {
let port = match NonZeroU64::new(rand::random()) {
Some(p) => p,
None => continue,
};
if !hub.contains_key(&port) {
break port;
}
}
};
let (tx, rx) = mpsc::channel(2);
match hub.entry(port) {
Entry::Occupied(_) =>
return Err(TransportError::Other(MemoryTransportError::Unreachable)),
Entry::Vacant(e) => e.insert(tx)
let (rx, port) = match HUB.register_port(port) {
Some((rx, port)) => (rx, port),
None => return Err(TransportError::Other(MemoryTransportError::Unreachable)),
};
let listener = Listener {
@ -124,19 +189,7 @@ impl Transport for MemoryTransport {
return Err(TransportError::MultiaddrNotSupported(addr));
};
let hub = HUB.lock();
if let Some(sender) = hub.get(&port) {
let (a_tx, a_rx) = mpsc::channel(4096);
let (b_tx, b_rx) = mpsc::channel(4096);
Ok(DialFuture {
sender: sender.clone(),
channel_to_send: Some(RwStreamSink::new(Chan { incoming: a_rx, outgoing: b_tx })),
channel_to_return: Some(RwStreamSink::new(Chan { incoming: b_rx, outgoing: a_tx })),
})
} else {
Err(TransportError::Other(MemoryTransportError::Unreachable))
}
DialFuture::new(port).ok_or(TransportError::Other(MemoryTransportError::Unreachable))
}
}
@ -167,7 +220,7 @@ pub struct Listener {
/// The address we are listening on.
addr: Multiaddr,
/// Receives incoming connections.
receiver: mpsc::Receiver<Channel<Vec<u8>>>,
receiver: ChannelReceiver,
/// Generate `ListenerEvent::NewAddress` to inform about our listen address.
tell_listen_addr: bool
}
@ -181,7 +234,7 @@ impl Stream for Listener {
return Poll::Ready(Some(Ok(ListenerEvent::NewAddress(self.addr.clone()))))
}
let channel = match Stream::poll_next(Pin::new(&mut self.receiver), cx) {
let (channel, dial_port) = match Stream::poll_next(Pin::new(&mut self.receiver), cx) {
Poll::Pending => return Poll::Pending,
Poll::Ready(None) => panic!("Alive listeners always have a sender."),
Poll::Ready(Some(v)) => v,
@ -190,7 +243,7 @@ impl Stream for Listener {
let event = ListenerEvent::Upgrade {
upgrade: future::ready(Ok(channel)),
local_addr: self.addr.clone(),
remote_addr: Protocol::Memory(self.port.get()).into()
remote_addr: Protocol::Memory(dial_port.get()).into()
};
Poll::Ready(Some(Ok(event)))
@ -199,7 +252,7 @@ impl Stream for Listener {
impl Drop for Listener {
fn drop(&mut self) {
let val_in = HUB.lock().remove(&self.port);
let val_in = HUB.unregister_port(&self.port);
debug_assert!(val_in.is_some());
}
}
@ -232,6 +285,14 @@ pub type Channel<T> = RwStreamSink<Chan<T>>;
pub struct Chan<T = Vec<u8>> {
incoming: mpsc::Receiver<T>,
outgoing: mpsc::Sender<T>,
// Needed in [`Drop`] implementation of [`Chan`] to unregister the dialing
// port with the global [`HUB`]. Is [`Some`] when [`Chan`] of dialer and
// [`None`] when [`Chan`] of listener.
//
// Note: Listening port is unregistered in [`Drop`] implementation of
// [`Listener`].
dial_port: Option<NonZeroU64>,
}
impl<T> Unpin for Chan<T> {
@ -276,6 +337,15 @@ impl<T: AsRef<[u8]>> Into<RwStreamSink<Chan<T>>> for Chan<T> {
}
}
impl<T> Drop for Chan<T> {
fn drop(&mut self) {
if let Some(port) = self.dial_port {
let channel_sender = HUB.unregister_port(&port);
debug_assert!(channel_sender.is_some());
}
}
}
#[cfg(test)]
mod tests {
use super::*;
@ -350,4 +420,90 @@ mod tests {
futures::executor::block_on(futures::future::join(listener, dialer));
}
#[test]
fn dialer_address_unequal_to_listener_address() {
let listener_addr: Multiaddr = Protocol::Memory(
rand::random::<u64>().saturating_add(1),
).into();
let listener_addr_cloned = listener_addr.clone();
let listener_transport = MemoryTransport::default();
let listener = async move {
let mut listener = listener_transport.listen_on(listener_addr.clone())
.unwrap();
while let Some(ev) = listener.next().await {
if let ListenerEvent::Upgrade { remote_addr, .. } = ev.unwrap() {
assert!(
remote_addr != listener_addr,
"Expect dialer address not to equal listener address."
);
return;
}
}
};
let dialer = async move {
MemoryTransport::default().dial(listener_addr_cloned)
.unwrap()
.await
.unwrap();
};
futures::executor::block_on(futures::future::join(listener, dialer));
}
#[test]
fn dialer_port_is_deregistered() {
let (terminate, should_terminate) = futures::channel::oneshot::channel();
let (terminated, is_terminated) = futures::channel::oneshot::channel();
let listener_addr: Multiaddr = Protocol::Memory(
rand::random::<u64>().saturating_add(1),
).into();
let listener_addr_cloned = listener_addr.clone();
let listener_transport = MemoryTransport::default();
let listener = async move {
let mut listener = listener_transport.listen_on(listener_addr.clone())
.unwrap();
while let Some(ev) = listener.next().await {
if let ListenerEvent::Upgrade { remote_addr, .. } = ev.unwrap() {
let dialer_port = NonZeroU64::new(
parse_memory_addr(&remote_addr).unwrap(),
).unwrap();
assert!(
HUB.get(&dialer_port).is_some(),
"Expect dialer port to stay registered while connection is in use.",
);
terminate.send(()).unwrap();
is_terminated.await.unwrap();
assert!(
HUB.get(&dialer_port).is_none(),
"Expect dialer port to be deregistered once connection is dropped.",
);
return;
}
}
};
let dialer = async move {
let _chan = MemoryTransport::default().dial(listener_addr_cloned)
.unwrap()
.await
.unwrap();
should_terminate.await.unwrap();
drop(_chan);
terminated.send(()).unwrap();
};
futures::executor::block_on(futures::future::join(listener, dialer));
}
}