diff --git a/transports/quic/src/transport.rs b/transports/quic/src/transport.rs index 9f66fe49..dea01c74 100644 --- a/transports/quic/src/transport.rs +++ b/transports/quic/src/transport.rs @@ -71,6 +71,7 @@ pub struct GenTransport { listeners: SelectAll>, /// Dialer for each socket family if no matching listener exists. dialer: HashMap, + dialer_waker: Option, } impl GenTransport

{ @@ -84,6 +85,7 @@ impl GenTransport

{ quinn_config, handshake_timeout, dialer: HashMap::new(), + dialer_waker: None, support_draft_29, } } @@ -178,6 +180,12 @@ impl Transport for GenTransport

{ &mut listeners[index].dialer_state } }; + + // Wakeup the task polling [`Transport::poll`] to drive the new dial. + if let Some(waker) = self.dialer_waker.take() { + waker.wake(); + } + Ok(dialer_state.new_dial(socket_addr, self.handshake_timeout, version)) } @@ -207,10 +215,14 @@ impl Transport for GenTransport

{ // Drop dialer and all pending dials so that the connection receiver is notified. self.dialer.remove(&key); } - match self.listeners.poll_next_unpin(cx) { - Poll::Ready(Some(ev)) => Poll::Ready(ev), - _ => Poll::Pending, + + if let Poll::Ready(Some(ev)) = self.listeners.poll_next_unpin(cx) { + return Poll::Ready(ev); } + + self.dialer_waker = Some(cx.waker().clone()); + + Poll::Pending } } @@ -254,12 +266,11 @@ impl Drop for Dialer { } } -/// Pending dials to be sent to the endpoint was the [`endpoint::Channel`] -/// has capacity +/// Pending dials to be sent to the endpoint once the [`endpoint::Channel`] +/// has capacity. #[derive(Default, Debug)] struct DialerState { pending_dials: VecDeque, - waker: Option, } impl DialerState { @@ -279,10 +290,6 @@ impl DialerState { self.pending_dials.push_back(message); - if let Some(waker) = self.waker.take() { - waker.wake(); - } - async move { // Our oneshot getting dropped means the message didn't make it to the endpoint driver. let connection = tx.await.map_err(|_| Error::EndpointDriverCrashed)??; @@ -307,7 +314,6 @@ impl DialerState { Err(endpoint::Disconnected {}) => return Poll::Ready(Error::EndpointDriverCrashed), } } - self.waker = Some(cx.waker().clone()); Poll::Pending } } diff --git a/transports/quic/tests/smoke.rs b/transports/quic/tests/smoke.rs index a1478645..649aca09 100644 --- a/transports/quic/tests/smoke.rs +++ b/transports/quic/tests/smoke.rs @@ -1,12 +1,15 @@ #![cfg(any(feature = "async-std", feature = "tokio"))] use futures::channel::{mpsc, oneshot}; +use futures::future::BoxFuture; use futures::future::{poll_fn, Either}; use futures::stream::StreamExt; use futures::{future, AsyncReadExt, AsyncWriteExt, FutureExt, SinkExt}; +use futures_timer::Delay; use libp2p_core::either::EitherOutput; use libp2p_core::muxing::{StreamMuxerBox, StreamMuxerExt, SubstreamBox}; use libp2p_core::transport::{Boxed, OrTransport, TransportEvent}; +use libp2p_core::transport::{ListenerId, TransportError}; use libp2p_core::{multiaddr::Protocol, upgrade, Multiaddr, PeerId, Transport}; use libp2p_noise as noise; use libp2p_quic as quic; @@ -19,6 +22,10 @@ use std::io; use std::num::NonZeroU8; use std::task::Poll; use std::time::Duration; +use std::{ + pin::Pin, + sync::{Arc, Mutex}, +}; #[cfg(feature = "tokio")] #[tokio::test] @@ -90,6 +97,113 @@ async fn ipv4_dial_ipv6() { assert_eq!(b_connected, a_peer_id); } +/// Tests that a [`Transport::dial`] wakes up the task previously polling [`Transport::poll`]. +/// +/// See https://github.com/libp2p/rust-libp2p/pull/3306 for context. +#[cfg(feature = "async-std")] +#[async_std::test] +async fn wrapped_with_dns() { + let _ = env_logger::try_init(); + + struct DialDelay(Arc>>); + + impl Transport for DialDelay { + type Output = (PeerId, StreamMuxerBox); + type Error = std::io::Error; + type ListenerUpgrade = Pin> + Send>>; + type Dial = BoxFuture<'static, Result>; + + fn listen_on( + &mut self, + addr: Multiaddr, + ) -> Result> { + self.0.lock().unwrap().listen_on(addr) + } + + fn remove_listener(&mut self, id: ListenerId) -> bool { + self.0.lock().unwrap().remove_listener(id) + } + + fn address_translation( + &self, + listen: &Multiaddr, + observed: &Multiaddr, + ) -> Option { + self.0.lock().unwrap().address_translation(listen, observed) + } + + /// Delayed dial, i.e. calling [`Transport::dial`] on the inner [`Transport`] not within the + /// synchronous [`Transport::dial`] method, but within the [`Future`] returned by the outer + /// [`Transport::dial`]. + fn dial(&mut self, addr: Multiaddr) -> Result> { + let t = self.0.clone(); + Ok(async move { + // Simulate DNS lookup. Giving the `Transport::poll` the chance to return + // `Poll::Pending` and thus suspending its task, waiting for a wakeup from the dial + // on the inner transport below. + Delay::new(Duration::from_millis(100)).await; + + let dial = t.lock().unwrap().dial(addr).map_err(|e| match e { + TransportError::MultiaddrNotSupported(_) => { + panic!() + } + TransportError::Other(e) => e, + })?; + dial.await + } + .boxed()) + } + + fn dial_as_listener( + &mut self, + addr: Multiaddr, + ) -> Result> { + self.0.lock().unwrap().dial_as_listener(addr) + } + + fn poll( + self: Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> Poll> { + Pin::new(&mut *self.0.lock().unwrap()).poll(cx) + } + } + + let (a_peer_id, mut a_transport) = create_default_transport::(); + let (b_peer_id, mut b_transport) = { + let (id, transport) = create_default_transport::(); + (id, DialDelay(Arc::new(Mutex::new(transport))).boxed()) + }; + + // Spawn a + let a_addr = start_listening(&mut a_transport, "/ip6/::1/udp/0/quic-v1").await; + let listener = async_std::task::spawn(async move { + let (upgrade, _) = a_transport + .select_next_some() + .await + .into_incoming() + .unwrap(); + let (peer_id, _) = upgrade.await.unwrap(); + + peer_id + }); + + // Spawn b + // + // Note that the dial is spawned on a different task than the transport allowing the transport + // task to poll the transport once and then suspend, waiting for the wakeup from the dial. + let dial = async_std::task::spawn({ + let dial = b_transport.dial(a_addr).unwrap(); + async { dial.await.unwrap().0 } + }); + async_std::task::spawn(async move { b_transport.next().await }); + + let (a_connected, b_connected) = future::join(listener, dial).await; + + assert_eq!(a_connected, b_peer_id); + assert_eq!(b_connected, a_peer_id); +} + #[cfg(feature = "async-std")] #[async_std::test] #[ignore] // Transport currently does not validate PeerId. Enable once we make use of PeerId validation in rustls.