mirror of
https://github.com/fluencelabs/rust-libp2p
synced 2025-05-29 18:51:22 +00:00
fix(quic): Trigger Quic as Transport wakeup on dial (#3306)
Scenario: rust-libp2p node A dials rust-libp2p node B. B listens on a QUIC address. A dials B via the `libp2p-quic` `Transport` wrapped in a `libp2p-dns` `Transport`. Note that `libp2p-dns` in itself is not relevant here. Only the fact that `libp2p-dns` delays a dial is relevant, i.e. that it first does other async stuff (DNS lookup) before creating the QUIC dial. In fact, dialing an IP address through the DNS `Transport` where no DNS resolution is needed triggers the below just fine. 1. A calls `Swarm::dial` which creates a `libp2p-dns` dial. 2. That dial is spawned onto the connection `Pool`, thus starting the DNS resolution. 3. A continuously calls `Swarm::poll`. 4. `libp2p-quic` `Transport::poll` is called, finding no dialers in `self.dialer` given that the spawned dial is still only resolving the DNS address. 5. On the spawned connection task: 1. The DNS resolution finishes. 2. Thus calling `Transport::dial` on `libp2p-quic` (note that the DNS dial has a clone of the QUIC `Transport` wrapped in an `Arc<Mutex<_>>`). 3. That adds a dialer to `self.dialer`. Note that there are no listeners, i.e. `Swarm::listen_on` was never called. 4. `DialerState::new_dial` is called which adds a message to `self.pending_dials` and wakes `self.waker`. Given that on the last `Transport::poll` there was no `self.dialer`, that waker is empty. Result: The message is stuck in the `DialerState::pending_dials`. The message is never sent to the endpoint driver. The dial never succeeds. This commit fixes the above, waking the `<Quic as Transport>::poll` method.
This commit is contained in:
parent
1b6c915813
commit
7665e74cdb
@ -71,6 +71,7 @@ pub struct GenTransport<P: Provider> {
|
|||||||
listeners: SelectAll<Listener<P>>,
|
listeners: SelectAll<Listener<P>>,
|
||||||
/// Dialer for each socket family if no matching listener exists.
|
/// Dialer for each socket family if no matching listener exists.
|
||||||
dialer: HashMap<SocketFamily, Dialer>,
|
dialer: HashMap<SocketFamily, Dialer>,
|
||||||
|
dialer_waker: Option<Waker>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<P: Provider> GenTransport<P> {
|
impl<P: Provider> GenTransport<P> {
|
||||||
@ -84,6 +85,7 @@ impl<P: Provider> GenTransport<P> {
|
|||||||
quinn_config,
|
quinn_config,
|
||||||
handshake_timeout,
|
handshake_timeout,
|
||||||
dialer: HashMap::new(),
|
dialer: HashMap::new(),
|
||||||
|
dialer_waker: None,
|
||||||
support_draft_29,
|
support_draft_29,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -178,6 +180,12 @@ impl<P: Provider> Transport for GenTransport<P> {
|
|||||||
&mut listeners[index].dialer_state
|
&mut listeners[index].dialer_state
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
// Wakeup the task polling [`Transport::poll`] to drive the new dial.
|
||||||
|
if let Some(waker) = self.dialer_waker.take() {
|
||||||
|
waker.wake();
|
||||||
|
}
|
||||||
|
|
||||||
Ok(dialer_state.new_dial(socket_addr, self.handshake_timeout, version))
|
Ok(dialer_state.new_dial(socket_addr, self.handshake_timeout, version))
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -207,10 +215,14 @@ impl<P: Provider> Transport for GenTransport<P> {
|
|||||||
// Drop dialer and all pending dials so that the connection receiver is notified.
|
// Drop dialer and all pending dials so that the connection receiver is notified.
|
||||||
self.dialer.remove(&key);
|
self.dialer.remove(&key);
|
||||||
}
|
}
|
||||||
match self.listeners.poll_next_unpin(cx) {
|
|
||||||
Poll::Ready(Some(ev)) => Poll::Ready(ev),
|
if let Poll::Ready(Some(ev)) = self.listeners.poll_next_unpin(cx) {
|
||||||
_ => Poll::Pending,
|
return Poll::Ready(ev);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
self.dialer_waker = Some(cx.waker().clone());
|
||||||
|
|
||||||
|
Poll::Pending
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -254,12 +266,11 @@ impl Drop for Dialer {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Pending dials to be sent to the endpoint was the [`endpoint::Channel`]
|
/// Pending dials to be sent to the endpoint once the [`endpoint::Channel`]
|
||||||
/// has capacity
|
/// has capacity.
|
||||||
#[derive(Default, Debug)]
|
#[derive(Default, Debug)]
|
||||||
struct DialerState {
|
struct DialerState {
|
||||||
pending_dials: VecDeque<ToEndpoint>,
|
pending_dials: VecDeque<ToEndpoint>,
|
||||||
waker: Option<Waker>,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
impl DialerState {
|
impl DialerState {
|
||||||
@ -279,10 +290,6 @@ impl DialerState {
|
|||||||
|
|
||||||
self.pending_dials.push_back(message);
|
self.pending_dials.push_back(message);
|
||||||
|
|
||||||
if let Some(waker) = self.waker.take() {
|
|
||||||
waker.wake();
|
|
||||||
}
|
|
||||||
|
|
||||||
async move {
|
async move {
|
||||||
// Our oneshot getting dropped means the message didn't make it to the endpoint driver.
|
// Our oneshot getting dropped means the message didn't make it to the endpoint driver.
|
||||||
let connection = tx.await.map_err(|_| Error::EndpointDriverCrashed)??;
|
let connection = tx.await.map_err(|_| Error::EndpointDriverCrashed)??;
|
||||||
@ -307,7 +314,6 @@ impl DialerState {
|
|||||||
Err(endpoint::Disconnected {}) => return Poll::Ready(Error::EndpointDriverCrashed),
|
Err(endpoint::Disconnected {}) => return Poll::Ready(Error::EndpointDriverCrashed),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
self.waker = Some(cx.waker().clone());
|
|
||||||
Poll::Pending
|
Poll::Pending
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -1,12 +1,15 @@
|
|||||||
#![cfg(any(feature = "async-std", feature = "tokio"))]
|
#![cfg(any(feature = "async-std", feature = "tokio"))]
|
||||||
|
|
||||||
use futures::channel::{mpsc, oneshot};
|
use futures::channel::{mpsc, oneshot};
|
||||||
|
use futures::future::BoxFuture;
|
||||||
use futures::future::{poll_fn, Either};
|
use futures::future::{poll_fn, Either};
|
||||||
use futures::stream::StreamExt;
|
use futures::stream::StreamExt;
|
||||||
use futures::{future, AsyncReadExt, AsyncWriteExt, FutureExt, SinkExt};
|
use futures::{future, AsyncReadExt, AsyncWriteExt, FutureExt, SinkExt};
|
||||||
|
use futures_timer::Delay;
|
||||||
use libp2p_core::either::EitherOutput;
|
use libp2p_core::either::EitherOutput;
|
||||||
use libp2p_core::muxing::{StreamMuxerBox, StreamMuxerExt, SubstreamBox};
|
use libp2p_core::muxing::{StreamMuxerBox, StreamMuxerExt, SubstreamBox};
|
||||||
use libp2p_core::transport::{Boxed, OrTransport, TransportEvent};
|
use libp2p_core::transport::{Boxed, OrTransport, TransportEvent};
|
||||||
|
use libp2p_core::transport::{ListenerId, TransportError};
|
||||||
use libp2p_core::{multiaddr::Protocol, upgrade, Multiaddr, PeerId, Transport};
|
use libp2p_core::{multiaddr::Protocol, upgrade, Multiaddr, PeerId, Transport};
|
||||||
use libp2p_noise as noise;
|
use libp2p_noise as noise;
|
||||||
use libp2p_quic as quic;
|
use libp2p_quic as quic;
|
||||||
@ -19,6 +22,10 @@ use std::io;
|
|||||||
use std::num::NonZeroU8;
|
use std::num::NonZeroU8;
|
||||||
use std::task::Poll;
|
use std::task::Poll;
|
||||||
use std::time::Duration;
|
use std::time::Duration;
|
||||||
|
use std::{
|
||||||
|
pin::Pin,
|
||||||
|
sync::{Arc, Mutex},
|
||||||
|
};
|
||||||
|
|
||||||
#[cfg(feature = "tokio")]
|
#[cfg(feature = "tokio")]
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
@ -90,6 +97,113 @@ async fn ipv4_dial_ipv6() {
|
|||||||
assert_eq!(b_connected, a_peer_id);
|
assert_eq!(b_connected, a_peer_id);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Tests that a [`Transport::dial`] wakes up the task previously polling [`Transport::poll`].
|
||||||
|
///
|
||||||
|
/// See https://github.com/libp2p/rust-libp2p/pull/3306 for context.
|
||||||
|
#[cfg(feature = "async-std")]
|
||||||
|
#[async_std::test]
|
||||||
|
async fn wrapped_with_dns() {
|
||||||
|
let _ = env_logger::try_init();
|
||||||
|
|
||||||
|
struct DialDelay(Arc<Mutex<Boxed<(PeerId, StreamMuxerBox)>>>);
|
||||||
|
|
||||||
|
impl Transport for DialDelay {
|
||||||
|
type Output = (PeerId, StreamMuxerBox);
|
||||||
|
type Error = std::io::Error;
|
||||||
|
type ListenerUpgrade = Pin<Box<dyn Future<Output = io::Result<Self::Output>> + Send>>;
|
||||||
|
type Dial = BoxFuture<'static, Result<Self::Output, Self::Error>>;
|
||||||
|
|
||||||
|
fn listen_on(
|
||||||
|
&mut self,
|
||||||
|
addr: Multiaddr,
|
||||||
|
) -> Result<ListenerId, TransportError<Self::Error>> {
|
||||||
|
self.0.lock().unwrap().listen_on(addr)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn remove_listener(&mut self, id: ListenerId) -> bool {
|
||||||
|
self.0.lock().unwrap().remove_listener(id)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn address_translation(
|
||||||
|
&self,
|
||||||
|
listen: &Multiaddr,
|
||||||
|
observed: &Multiaddr,
|
||||||
|
) -> Option<Multiaddr> {
|
||||||
|
self.0.lock().unwrap().address_translation(listen, observed)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Delayed dial, i.e. calling [`Transport::dial`] on the inner [`Transport`] not within the
|
||||||
|
/// synchronous [`Transport::dial`] method, but within the [`Future`] returned by the outer
|
||||||
|
/// [`Transport::dial`].
|
||||||
|
fn dial(&mut self, addr: Multiaddr) -> Result<Self::Dial, TransportError<Self::Error>> {
|
||||||
|
let t = self.0.clone();
|
||||||
|
Ok(async move {
|
||||||
|
// Simulate DNS lookup. Giving the `Transport::poll` the chance to return
|
||||||
|
// `Poll::Pending` and thus suspending its task, waiting for a wakeup from the dial
|
||||||
|
// on the inner transport below.
|
||||||
|
Delay::new(Duration::from_millis(100)).await;
|
||||||
|
|
||||||
|
let dial = t.lock().unwrap().dial(addr).map_err(|e| match e {
|
||||||
|
TransportError::MultiaddrNotSupported(_) => {
|
||||||
|
panic!()
|
||||||
|
}
|
||||||
|
TransportError::Other(e) => e,
|
||||||
|
})?;
|
||||||
|
dial.await
|
||||||
|
}
|
||||||
|
.boxed())
|
||||||
|
}
|
||||||
|
|
||||||
|
fn dial_as_listener(
|
||||||
|
&mut self,
|
||||||
|
addr: Multiaddr,
|
||||||
|
) -> Result<Self::Dial, TransportError<Self::Error>> {
|
||||||
|
self.0.lock().unwrap().dial_as_listener(addr)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn poll(
|
||||||
|
self: Pin<&mut Self>,
|
||||||
|
cx: &mut std::task::Context<'_>,
|
||||||
|
) -> Poll<TransportEvent<Self::ListenerUpgrade, Self::Error>> {
|
||||||
|
Pin::new(&mut *self.0.lock().unwrap()).poll(cx)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
let (a_peer_id, mut a_transport) = create_default_transport::<quic::async_std::Provider>();
|
||||||
|
let (b_peer_id, mut b_transport) = {
|
||||||
|
let (id, transport) = create_default_transport::<quic::async_std::Provider>();
|
||||||
|
(id, DialDelay(Arc::new(Mutex::new(transport))).boxed())
|
||||||
|
};
|
||||||
|
|
||||||
|
// Spawn a
|
||||||
|
let a_addr = start_listening(&mut a_transport, "/ip6/::1/udp/0/quic-v1").await;
|
||||||
|
let listener = async_std::task::spawn(async move {
|
||||||
|
let (upgrade, _) = a_transport
|
||||||
|
.select_next_some()
|
||||||
|
.await
|
||||||
|
.into_incoming()
|
||||||
|
.unwrap();
|
||||||
|
let (peer_id, _) = upgrade.await.unwrap();
|
||||||
|
|
||||||
|
peer_id
|
||||||
|
});
|
||||||
|
|
||||||
|
// Spawn b
|
||||||
|
//
|
||||||
|
// Note that the dial is spawned on a different task than the transport allowing the transport
|
||||||
|
// task to poll the transport once and then suspend, waiting for the wakeup from the dial.
|
||||||
|
let dial = async_std::task::spawn({
|
||||||
|
let dial = b_transport.dial(a_addr).unwrap();
|
||||||
|
async { dial.await.unwrap().0 }
|
||||||
|
});
|
||||||
|
async_std::task::spawn(async move { b_transport.next().await });
|
||||||
|
|
||||||
|
let (a_connected, b_connected) = future::join(listener, dial).await;
|
||||||
|
|
||||||
|
assert_eq!(a_connected, b_peer_id);
|
||||||
|
assert_eq!(b_connected, a_peer_id);
|
||||||
|
}
|
||||||
|
|
||||||
#[cfg(feature = "async-std")]
|
#[cfg(feature = "async-std")]
|
||||||
#[async_std::test]
|
#[async_std::test]
|
||||||
#[ignore] // Transport currently does not validate PeerId. Enable once we make use of PeerId validation in rustls.
|
#[ignore] // Transport currently does not validate PeerId. Enable once we make use of PeerId validation in rustls.
|
||||||
|
Loading…
x
Reference in New Issue
Block a user