mirror of
https://github.com/fluencelabs/rust-libp2p
synced 2025-06-27 08:41:36 +00:00
core/transport/memory: Return dialer address in Upgrade event (#1724)
Previously a `Listener` would return its own address as the `remote_addr` in the `ListenerEvent::Upgrade` event. With this commit a `Listener` returns the dialer address as the `remote_addr` in the `ListenerEvent::Upgrade` event. To do so a `DialFuture` registers a port with the global `HUB` at construction which is later on unregistered in the `Drop` implementation of the dialer's `Chan`. The sending side of the `mpsc::channel` registered in the `HUB` is dropped at `DialFuture` construction, thus one can not dial the dialer port. This mimics the TCP transport behaviour preventing both dialing and listening on the same TCP port.
This commit is contained in:
@ -400,7 +400,7 @@ mod tests {
|
||||
match listeners.next().await.unwrap() {
|
||||
ListenersEvent::Incoming { local_addr, send_back_addr, .. } => {
|
||||
assert_eq!(local_addr, address);
|
||||
assert_eq!(send_back_addr, address);
|
||||
assert!(send_back_addr != address);
|
||||
},
|
||||
_ => panic!()
|
||||
}
|
||||
|
@ -28,8 +28,57 @@ use rw_stream_sink::RwStreamSink;
|
||||
use std::{collections::hash_map::Entry, error, fmt, io, num::NonZeroU64, pin::Pin};
|
||||
|
||||
lazy_static! {
|
||||
static ref HUB: Mutex<FnvHashMap<NonZeroU64, mpsc::Sender<Channel<Vec<u8>>>>> =
|
||||
Mutex::new(FnvHashMap::default());
|
||||
static ref HUB: Hub = Hub(Mutex::new(FnvHashMap::default()));
|
||||
}
|
||||
|
||||
struct Hub(Mutex<FnvHashMap<NonZeroU64, ChannelSender>>);
|
||||
|
||||
/// A [`mpsc::Sender`] enabling a [`DialFuture`] to send a [`Channel`] and the
|
||||
/// port of the dialer to a [`Listener`].
|
||||
type ChannelSender = mpsc::Sender<(Channel<Vec<u8>>, NonZeroU64)>;
|
||||
|
||||
/// A [`mpsc::Receiver`] enabling a [`Listener`] to receive a [`Channel`] and
|
||||
/// the port of the dialer from a [`DialFuture`].
|
||||
type ChannelReceiver = mpsc::Receiver<(Channel<Vec<u8>>, NonZeroU64)>;
|
||||
|
||||
impl Hub {
|
||||
/// Registers the given port on the hub.
|
||||
///
|
||||
/// Randomizes port when given port is `0`. Returns [`None`] when given port
|
||||
/// is already occupied.
|
||||
fn register_port(&self, port: u64) -> Option<(ChannelReceiver, NonZeroU64)> {
|
||||
let mut hub = self.0.lock();
|
||||
|
||||
let port = if let Some(port) = NonZeroU64::new(port) {
|
||||
port
|
||||
} else {
|
||||
loop {
|
||||
let port = match NonZeroU64::new(rand::random()) {
|
||||
Some(p) => p,
|
||||
None => continue,
|
||||
};
|
||||
if !hub.contains_key(&port) {
|
||||
break port;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
let (tx, rx) = mpsc::channel(2);
|
||||
match hub.entry(port) {
|
||||
Entry::Occupied(_) => return None,
|
||||
Entry::Vacant(e) => e.insert(tx)
|
||||
};
|
||||
|
||||
Some((rx, port))
|
||||
}
|
||||
|
||||
fn unregister_port(&self, port: &NonZeroU64) -> Option<ChannelSender> {
|
||||
self.0.lock().remove(port)
|
||||
}
|
||||
|
||||
fn get(&self, port: &NonZeroU64) -> Option<ChannelSender> {
|
||||
self.0.lock().get(port).cloned()
|
||||
}
|
||||
}
|
||||
|
||||
/// Transport that supports `/memory/N` multiaddresses.
|
||||
@ -38,15 +87,49 @@ pub struct MemoryTransport;
|
||||
|
||||
/// Connection to a `MemoryTransport` currently being opened.
|
||||
pub struct DialFuture {
|
||||
sender: mpsc::Sender<Channel<Vec<u8>>>,
|
||||
/// Ephemeral source port.
|
||||
///
|
||||
/// These ports mimic TCP ephemeral source ports but are not actually used
|
||||
/// by the memory transport due to the direct use of channels. They merely
|
||||
/// ensure that every connection has a unique address for each dialer, which
|
||||
/// is not at the same time a listen address (analogous to TCP).
|
||||
dial_port: NonZeroU64,
|
||||
sender: ChannelSender,
|
||||
channel_to_send: Option<Channel<Vec<u8>>>,
|
||||
channel_to_return: Option<Channel<Vec<u8>>>,
|
||||
}
|
||||
|
||||
impl DialFuture {
|
||||
fn new(port: NonZeroU64) -> Option<Self> {
|
||||
let sender = HUB.get(&port)?.clone();
|
||||
|
||||
let (_dial_port_channel, dial_port) = HUB.register_port(0)
|
||||
.expect("there to be some random unoccupied port.");
|
||||
|
||||
let (a_tx, a_rx) = mpsc::channel(4096);
|
||||
let (b_tx, b_rx) = mpsc::channel(4096);
|
||||
Some(DialFuture {
|
||||
dial_port,
|
||||
sender,
|
||||
channel_to_send: Some(RwStreamSink::new(Chan {
|
||||
incoming: a_rx,
|
||||
outgoing: b_tx,
|
||||
dial_port: None,
|
||||
})),
|
||||
channel_to_return: Some(RwStreamSink::new(Chan {
|
||||
incoming: b_rx,
|
||||
outgoing: a_tx,
|
||||
dial_port: Some(dial_port),
|
||||
})),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl Future for DialFuture {
|
||||
type Output = Result<Channel<Vec<u8>>, MemoryTransportError>;
|
||||
|
||||
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
|
||||
|
||||
match self.sender.poll_ready(cx) {
|
||||
Poll::Pending => return Poll::Pending,
|
||||
Poll::Ready(Ok(())) => {},
|
||||
@ -55,7 +138,8 @@ impl Future for DialFuture {
|
||||
|
||||
let channel_to_send = self.channel_to_send.take()
|
||||
.expect("Future should not be polled again once complete");
|
||||
match self.sender.start_send(channel_to_send) {
|
||||
let dial_port = self.dial_port;
|
||||
match self.sender.start_send((channel_to_send, dial_port)) {
|
||||
Err(_) => return Poll::Ready(Err(MemoryTransportError::Unreachable)),
|
||||
Ok(()) => {}
|
||||
}
|
||||
@ -79,28 +163,9 @@ impl Transport for MemoryTransport {
|
||||
return Err(TransportError::MultiaddrNotSupported(addr));
|
||||
};
|
||||
|
||||
let mut hub = (&*HUB).lock();
|
||||
|
||||
let port = if let Some(port) = NonZeroU64::new(port) {
|
||||
port
|
||||
} else {
|
||||
loop {
|
||||
let port = match NonZeroU64::new(rand::random()) {
|
||||
Some(p) => p,
|
||||
None => continue,
|
||||
};
|
||||
if !hub.contains_key(&port) {
|
||||
break port;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
let (tx, rx) = mpsc::channel(2);
|
||||
match hub.entry(port) {
|
||||
Entry::Occupied(_) =>
|
||||
return Err(TransportError::Other(MemoryTransportError::Unreachable)),
|
||||
Entry::Vacant(e) => e.insert(tx)
|
||||
let (rx, port) = match HUB.register_port(port) {
|
||||
Some((rx, port)) => (rx, port),
|
||||
None => return Err(TransportError::Other(MemoryTransportError::Unreachable)),
|
||||
};
|
||||
|
||||
let listener = Listener {
|
||||
@ -124,19 +189,7 @@ impl Transport for MemoryTransport {
|
||||
return Err(TransportError::MultiaddrNotSupported(addr));
|
||||
};
|
||||
|
||||
let hub = HUB.lock();
|
||||
if let Some(sender) = hub.get(&port) {
|
||||
let (a_tx, a_rx) = mpsc::channel(4096);
|
||||
let (b_tx, b_rx) = mpsc::channel(4096);
|
||||
Ok(DialFuture {
|
||||
sender: sender.clone(),
|
||||
channel_to_send: Some(RwStreamSink::new(Chan { incoming: a_rx, outgoing: b_tx })),
|
||||
channel_to_return: Some(RwStreamSink::new(Chan { incoming: b_rx, outgoing: a_tx })),
|
||||
|
||||
})
|
||||
} else {
|
||||
Err(TransportError::Other(MemoryTransportError::Unreachable))
|
||||
}
|
||||
DialFuture::new(port).ok_or(TransportError::Other(MemoryTransportError::Unreachable))
|
||||
}
|
||||
}
|
||||
|
||||
@ -167,7 +220,7 @@ pub struct Listener {
|
||||
/// The address we are listening on.
|
||||
addr: Multiaddr,
|
||||
/// Receives incoming connections.
|
||||
receiver: mpsc::Receiver<Channel<Vec<u8>>>,
|
||||
receiver: ChannelReceiver,
|
||||
/// Generate `ListenerEvent::NewAddress` to inform about our listen address.
|
||||
tell_listen_addr: bool
|
||||
}
|
||||
@ -181,7 +234,7 @@ impl Stream for Listener {
|
||||
return Poll::Ready(Some(Ok(ListenerEvent::NewAddress(self.addr.clone()))))
|
||||
}
|
||||
|
||||
let channel = match Stream::poll_next(Pin::new(&mut self.receiver), cx) {
|
||||
let (channel, dial_port) = match Stream::poll_next(Pin::new(&mut self.receiver), cx) {
|
||||
Poll::Pending => return Poll::Pending,
|
||||
Poll::Ready(None) => panic!("Alive listeners always have a sender."),
|
||||
Poll::Ready(Some(v)) => v,
|
||||
@ -190,7 +243,7 @@ impl Stream for Listener {
|
||||
let event = ListenerEvent::Upgrade {
|
||||
upgrade: future::ready(Ok(channel)),
|
||||
local_addr: self.addr.clone(),
|
||||
remote_addr: Protocol::Memory(self.port.get()).into()
|
||||
remote_addr: Protocol::Memory(dial_port.get()).into()
|
||||
};
|
||||
|
||||
Poll::Ready(Some(Ok(event)))
|
||||
@ -199,7 +252,7 @@ impl Stream for Listener {
|
||||
|
||||
impl Drop for Listener {
|
||||
fn drop(&mut self) {
|
||||
let val_in = HUB.lock().remove(&self.port);
|
||||
let val_in = HUB.unregister_port(&self.port);
|
||||
debug_assert!(val_in.is_some());
|
||||
}
|
||||
}
|
||||
@ -232,6 +285,14 @@ pub type Channel<T> = RwStreamSink<Chan<T>>;
|
||||
pub struct Chan<T = Vec<u8>> {
|
||||
incoming: mpsc::Receiver<T>,
|
||||
outgoing: mpsc::Sender<T>,
|
||||
|
||||
// Needed in [`Drop`] implementation of [`Chan`] to unregister the dialing
|
||||
// port with the global [`HUB`]. Is [`Some`] when [`Chan`] of dialer and
|
||||
// [`None`] when [`Chan`] of listener.
|
||||
//
|
||||
// Note: Listening port is unregistered in [`Drop`] implementation of
|
||||
// [`Listener`].
|
||||
dial_port: Option<NonZeroU64>,
|
||||
}
|
||||
|
||||
impl<T> Unpin for Chan<T> {
|
||||
@ -276,6 +337,15 @@ impl<T: AsRef<[u8]>> Into<RwStreamSink<Chan<T>>> for Chan<T> {
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> Drop for Chan<T> {
|
||||
fn drop(&mut self) {
|
||||
if let Some(port) = self.dial_port {
|
||||
let channel_sender = HUB.unregister_port(&port);
|
||||
debug_assert!(channel_sender.is_some());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
@ -350,4 +420,90 @@ mod tests {
|
||||
|
||||
futures::executor::block_on(futures::future::join(listener, dialer));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn dialer_address_unequal_to_listener_address() {
|
||||
let listener_addr: Multiaddr = Protocol::Memory(
|
||||
rand::random::<u64>().saturating_add(1),
|
||||
).into();
|
||||
let listener_addr_cloned = listener_addr.clone();
|
||||
|
||||
let listener_transport = MemoryTransport::default();
|
||||
|
||||
let listener = async move {
|
||||
let mut listener = listener_transport.listen_on(listener_addr.clone())
|
||||
.unwrap();
|
||||
while let Some(ev) = listener.next().await {
|
||||
if let ListenerEvent::Upgrade { remote_addr, .. } = ev.unwrap() {
|
||||
assert!(
|
||||
remote_addr != listener_addr,
|
||||
"Expect dialer address not to equal listener address."
|
||||
);
|
||||
return;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
let dialer = async move {
|
||||
MemoryTransport::default().dial(listener_addr_cloned)
|
||||
.unwrap()
|
||||
.await
|
||||
.unwrap();
|
||||
};
|
||||
|
||||
futures::executor::block_on(futures::future::join(listener, dialer));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn dialer_port_is_deregistered() {
|
||||
let (terminate, should_terminate) = futures::channel::oneshot::channel();
|
||||
let (terminated, is_terminated) = futures::channel::oneshot::channel();
|
||||
|
||||
let listener_addr: Multiaddr = Protocol::Memory(
|
||||
rand::random::<u64>().saturating_add(1),
|
||||
).into();
|
||||
let listener_addr_cloned = listener_addr.clone();
|
||||
|
||||
let listener_transport = MemoryTransport::default();
|
||||
|
||||
let listener = async move {
|
||||
let mut listener = listener_transport.listen_on(listener_addr.clone())
|
||||
.unwrap();
|
||||
while let Some(ev) = listener.next().await {
|
||||
if let ListenerEvent::Upgrade { remote_addr, .. } = ev.unwrap() {
|
||||
let dialer_port = NonZeroU64::new(
|
||||
parse_memory_addr(&remote_addr).unwrap(),
|
||||
).unwrap();
|
||||
|
||||
assert!(
|
||||
HUB.get(&dialer_port).is_some(),
|
||||
"Expect dialer port to stay registered while connection is in use.",
|
||||
);
|
||||
|
||||
terminate.send(()).unwrap();
|
||||
is_terminated.await.unwrap();
|
||||
|
||||
assert!(
|
||||
HUB.get(&dialer_port).is_none(),
|
||||
"Expect dialer port to be deregistered once connection is dropped.",
|
||||
);
|
||||
|
||||
return;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
let dialer = async move {
|
||||
let _chan = MemoryTransport::default().dial(listener_addr_cloned)
|
||||
.unwrap()
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
should_terminate.await.unwrap();
|
||||
drop(_chan);
|
||||
terminated.send(()).unwrap();
|
||||
};
|
||||
|
||||
futures::executor::block_on(futures::future::join(listener, dialer));
|
||||
}
|
||||
}
|
||||
|
Reference in New Issue
Block a user