refactor(quic): rewrite quic using quinn

Rewrite quic using quinn instead of quinn-proto. libp2p-quic::endpoint::Driver is eliminated (and that hard quinn-proto machinery). Also:
- ECN bits are handled
- Support Generic Send Offload (GSO)

Pull-Request: #3454.
This commit is contained in:
Roman
2023-07-28 18:22:03 +08:00
committed by GitHub
parent f10f1a274a
commit 4e1fa9b8b8
17 changed files with 615 additions and 1810 deletions

37
Cargo.lock generated
View File

@ -3040,7 +3040,7 @@ dependencies = [
[[package]]
name = "libp2p-quic"
version = "0.8.0-alpha"
version = "0.9.0-alpha"
dependencies = [
"async-std",
"bytes",
@ -3058,7 +3058,7 @@ dependencies = [
"log",
"parking_lot",
"quickcheck",
"quinn-proto",
"quinn",
"rand 0.8.5",
"rustls 0.21.2",
"thiserror",
@ -4313,6 +4313,26 @@ dependencies = [
"pin-project-lite 0.1.12",
]
[[package]]
name = "quinn"
version = "0.10.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "21252f1c0fc131f1b69182db8f34837e8a69737b8251dff75636a9be0518c324"
dependencies = [
"async-io",
"async-std",
"bytes",
"futures-io",
"pin-project-lite 0.2.9",
"quinn-proto",
"quinn-udp",
"rustc-hash",
"rustls 0.21.2",
"thiserror",
"tokio",
"tracing",
]
[[package]]
name = "quinn-proto"
version = "0.10.1"
@ -4330,6 +4350,19 @@ dependencies = [
"tracing",
]
[[package]]
name = "quinn-udp"
version = "0.4.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6df19e284d93757a9fb91d63672f7741b129246a669db09d1c0063071debc0c0"
dependencies = [
"bytes",
"libc",
"socket2 0.5.3",
"tracing",
"windows-sys 0.48.0",
]
[[package]]
name = "quote"
version = "1.0.32"

View File

@ -83,7 +83,7 @@ libp2p-perf = { version = "0.2.0", path = "protocols/perf" }
libp2p-ping = { version = "0.43.0", path = "protocols/ping" }
libp2p-plaintext = { version = "0.40.0", path = "transports/plaintext" }
libp2p-pnet = { version = "0.23.0", path = "transports/pnet" }
libp2p-quic = { version = "0.8.0-alpha", path = "transports/quic" }
libp2p-quic = { version = "0.9.0-alpha", path = "transports/quic" }
libp2p-relay = { version = "0.16.1", path = "protocols/relay" }
libp2p-rendezvous = { version = "0.13.0", path = "protocols/rendezvous" }
libp2p-request-response = { version = "0.25.1", path = "protocols/request-response" }

View File

@ -1,3 +1,10 @@
## 0.9.0-alpha - unreleased
- Use `quinn` instead of `quinn-proto`.
See [PR 3454].
[PR 3454]: https://github.com/libp2p/rust-libp2p/pull/3454
## 0.8.0-alpha
- Raise MSRV to 1.65.

View File

@ -1,6 +1,6 @@
[package]
name = "libp2p-quic"
version = "0.8.0-alpha"
version = "0.9.0-alpha"
authors = ["Parity Technologies <admin@parity.io>"]
edition = "2021"
rust-version = { workspace = true }
@ -19,15 +19,15 @@ libp2p-tls = { workspace = true }
libp2p-identity = { workspace = true }
log = "0.4"
parking_lot = "0.12.0"
quinn-proto = { version = "0.10.1", default-features = false, features = ["tls-rustls"] }
quinn = { version = "0.10.1", default-features = false, features = ["tls-rustls", "futures-io"] }
rand = "0.8.5"
rustls = { version = "0.21.2", default-features = false }
thiserror = "1.0.44"
tokio = { version = "1.29.1", default-features = false, features = ["net", "rt", "time"], optional = true }
[features]
tokio = ["dep:tokio", "if-watch/tokio"]
async-std = ["dep:async-std", "if-watch/smol"]
tokio = ["dep:tokio", "if-watch/tokio", "quinn/runtime-tokio"]
async-std = ["dep:async-std", "if-watch/smol", "quinn/runtime-async-std"]
# Passing arguments to the docsrs builder in order to properly document cfg's.
# More information: https://docs.rs/about/builds#cross-compiling

View File

@ -0,0 +1,142 @@
// Copyright 2017-2020 Parity Technologies (UK) Ltd.
//
// Permission is hereby granted, free of charge, to any person obtaining a
// copy of this software and associated documentation files (the "Software"),
// to deal in the Software without restriction, including without limitation
// the rights to use, copy, modify, merge, publish, distribute, sublicense,
// and/or sell copies of the Software, and to permit persons to whom the
// Software is furnished to do so, subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in
// all copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
// OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
// FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
// DEALINGS IN THE SOFTWARE.
use quinn::VarInt;
use std::{sync::Arc, time::Duration};
/// Config for the transport.
#[derive(Clone)]
pub struct Config {
/// Timeout for the initial handshake when establishing a connection.
/// The actual timeout is the minimum of this and the [`Config::max_idle_timeout`].
pub handshake_timeout: Duration,
/// Maximum duration of inactivity in ms to accept before timing out the connection.
pub max_idle_timeout: u32,
/// Period of inactivity before sending a keep-alive packet.
/// Must be set lower than the idle_timeout of both
/// peers to be effective.
///
/// See [`quinn::TransportConfig::keep_alive_interval`] for more
/// info.
pub keep_alive_interval: Duration,
/// Maximum number of incoming bidirectional streams that may be open
/// concurrently by the remote peer.
pub max_concurrent_stream_limit: u32,
/// Max unacknowledged data in bytes that may be send on a single stream.
pub max_stream_data: u32,
/// Max unacknowledged data in bytes that may be send in total on all streams
/// of a connection.
pub max_connection_data: u32,
/// Support QUIC version draft-29 for dialing and listening.
///
/// Per default only QUIC Version 1 / [`libp2p_core::multiaddr::Protocol::QuicV1`]
/// is supported.
///
/// If support for draft-29 is enabled servers support draft-29 and version 1 on all
/// QUIC listening addresses.
/// As client the version is chosen based on the remote's address.
pub support_draft_29: bool,
/// TLS client config for the inner [`quinn::ClientConfig`].
client_tls_config: Arc<rustls::ClientConfig>,
/// TLS server config for the inner [`quinn::ServerConfig`].
server_tls_config: Arc<rustls::ServerConfig>,
}
impl Config {
/// Creates a new configuration object with default values.
pub fn new(keypair: &libp2p_identity::Keypair) -> Self {
let client_tls_config = Arc::new(libp2p_tls::make_client_config(keypair, None).unwrap());
let server_tls_config = Arc::new(libp2p_tls::make_server_config(keypair).unwrap());
Self {
client_tls_config,
server_tls_config,
support_draft_29: false,
handshake_timeout: Duration::from_secs(5),
max_idle_timeout: 30 * 1000,
max_concurrent_stream_limit: 256,
keep_alive_interval: Duration::from_secs(15),
max_connection_data: 15_000_000,
// Ensure that one stream is not consuming the whole connection.
max_stream_data: 10_000_000,
}
}
}
/// Represents the inner configuration for [`quinn`].
#[derive(Debug, Clone)]
pub(crate) struct QuinnConfig {
pub(crate) client_config: quinn::ClientConfig,
pub(crate) server_config: quinn::ServerConfig,
pub(crate) endpoint_config: quinn::EndpointConfig,
}
impl From<Config> for QuinnConfig {
fn from(config: Config) -> QuinnConfig {
let Config {
client_tls_config,
server_tls_config,
max_idle_timeout,
max_concurrent_stream_limit,
keep_alive_interval,
max_connection_data,
max_stream_data,
support_draft_29,
handshake_timeout: _,
} = config;
let mut transport = quinn::TransportConfig::default();
// Disable uni-directional streams.
transport.max_concurrent_uni_streams(0u32.into());
transport.max_concurrent_bidi_streams(max_concurrent_stream_limit.into());
// Disable datagrams.
transport.datagram_receive_buffer_size(None);
transport.keep_alive_interval(Some(keep_alive_interval));
transport.max_idle_timeout(Some(VarInt::from_u32(max_idle_timeout).into()));
transport.allow_spin(false);
transport.stream_receive_window(max_stream_data.into());
transport.receive_window(max_connection_data.into());
let transport = Arc::new(transport);
let mut server_config = quinn::ServerConfig::with_crypto(server_tls_config);
server_config.transport = Arc::clone(&transport);
// Disables connection migration.
// Long-term this should be enabled, however we then need to handle address change
// on connections in the `Connection`.
server_config.migration(false);
let mut client_config = quinn::ClientConfig::new(client_tls_config);
client_config.transport_config(transport);
let mut endpoint_config = quinn::EndpointConfig::default();
if !support_draft_29 {
endpoint_config.supported_versions(vec![1]);
}
QuinnConfig {
client_config,
server_config,
endpoint_config,
}
}
}

View File

@ -19,409 +19,113 @@
// DEALINGS IN THE SOFTWARE.
mod connecting;
mod substream;
mod stream;
use crate::{
endpoint::{self, ToEndpoint},
Error,
};
pub use connecting::Connecting;
pub use substream::Substream;
use substream::{SubstreamState, WriteState};
pub use stream::Stream;
use futures::{channel::mpsc, ready, FutureExt, StreamExt};
use futures_timer::Delay;
use crate::{ConnectionError, Error};
use futures::{future::BoxFuture, FutureExt};
use libp2p_core::muxing::{StreamMuxer, StreamMuxerEvent};
use parking_lot::Mutex;
use std::{
any::Any,
collections::HashMap,
net::SocketAddr,
pin::Pin,
sync::Arc,
task::{Context, Poll, Waker},
time::Instant,
task::{Context, Poll},
};
/// State for a single opened QUIC connection.
#[derive(Debug)]
pub struct Connection {
/// State shared with the substreams.
state: Arc<Mutex<State>>,
/// Channel to the [`endpoint::Driver`] that drives the [`quinn_proto::Endpoint`] that
/// this connection belongs to.
endpoint_channel: endpoint::Channel,
/// Pending message to be sent to the [`quinn_proto::Endpoint`] in the [`endpoint::Driver`].
pending_to_endpoint: Option<ToEndpoint>,
/// Events that the [`quinn_proto::Endpoint`] will send in destination to our local
/// [`quinn_proto::Connection`].
from_endpoint: mpsc::Receiver<quinn_proto::ConnectionEvent>,
/// Identifier for this connection according to the [`quinn_proto::Endpoint`].
/// Used when sending messages to the endpoint.
connection_id: quinn_proto::ConnectionHandle,
/// `Future` that triggers at the [`Instant`] that [`quinn_proto::Connection::poll_timeout`]
/// indicates.
next_timeout: Option<(Delay, Instant)>,
/// Underlying connection.
connection: quinn::Connection,
/// Future for accepting a new incoming bidirectional stream.
incoming: Option<
BoxFuture<'static, Result<(quinn::SendStream, quinn::RecvStream), quinn::ConnectionError>>,
>,
/// Future for opening a new outgoing bidirectional stream.
outgoing: Option<
BoxFuture<'static, Result<(quinn::SendStream, quinn::RecvStream), quinn::ConnectionError>>,
>,
/// Future to wait for the connection to be closed.
closing: Option<BoxFuture<'static, quinn::ConnectionError>>,
}
impl Connection {
/// Build a [`Connection`] from raw components.
///
/// This function assumes that there exists a [`Driver`](super::endpoint::Driver)
/// that will process the messages sent to `EndpointChannel::to_endpoint` and send us messages
/// on `from_endpoint`.
///
/// `connection_id` is used to identify the local connection in the messages sent to
/// `to_endpoint`.
///
/// This function assumes that the [`quinn_proto::Connection`] is completely fresh and none of
/// This function assumes that the [`quinn::Connection`] is completely fresh and none of
/// its methods has ever been called. Failure to comply might lead to logic errors and panics.
pub(crate) fn from_quinn_connection(
endpoint_channel: endpoint::Channel,
connection: quinn_proto::Connection,
connection_id: quinn_proto::ConnectionHandle,
from_endpoint: mpsc::Receiver<quinn_proto::ConnectionEvent>,
) -> Self {
let state = State {
connection,
substreams: HashMap::new(),
poll_connection_waker: None,
poll_inbound_waker: None,
poll_outbound_waker: None,
};
fn new(connection: quinn::Connection) -> Self {
Self {
endpoint_channel,
pending_to_endpoint: None,
next_timeout: None,
from_endpoint,
connection_id,
state: Arc::new(Mutex::new(state)),
}
}
/// The address that the local socket is bound to.
pub(crate) fn local_addr(&self) -> &SocketAddr {
self.endpoint_channel.socket_addr()
}
/// Returns the address of the node we're connected to.
pub(crate) fn remote_addr(&self) -> SocketAddr {
self.state.lock().connection.remote_address()
}
/// Identity of the remote peer inferred from the handshake.
///
/// `None` if the handshake is not complete yet, i.e. [`Self::poll_event`]
/// has not yet reported a [`quinn_proto::Event::Connected`]
fn peer_identity(&self) -> Option<Box<dyn Any>> {
self.state
.lock()
.connection
.crypto_session()
.peer_identity()
}
/// Polls the connection for an event that happened on it.
///
/// `quinn::proto::Connection` is polled in the order instructed in their docs:
/// 1. [`quinn_proto::Connection::poll_transmit`]
/// 2. [`quinn_proto::Connection::poll_timeout`]
/// 3. [`quinn_proto::Connection::poll_endpoint_events`]
/// 4. [`quinn_proto::Connection::poll`]
fn poll_event(&mut self, cx: &mut Context<'_>) -> Poll<Option<quinn_proto::Event>> {
let mut inner = self.state.lock();
loop {
// Sending the pending event to the endpoint. If the endpoint is too busy, we just
// stop the processing here.
// We don't deliver substream-related events to the user as long as
// `to_endpoint` is full. This should propagate the back-pressure of `to_endpoint`
// being full to the user.
if let Some(to_endpoint) = self.pending_to_endpoint.take() {
match self.endpoint_channel.try_send(to_endpoint, cx) {
Ok(Ok(())) => {}
Ok(Err(to_endpoint)) => {
self.pending_to_endpoint = Some(to_endpoint);
return Poll::Pending;
}
Err(endpoint::Disconnected {}) => {
return Poll::Ready(None);
}
}
}
match self.from_endpoint.poll_next_unpin(cx) {
Poll::Ready(Some(event)) => {
inner.connection.handle_event(event);
continue;
}
Poll::Ready(None) => {
return Poll::Ready(None);
}
Poll::Pending => {}
}
// The maximum amount of segments which can be transmitted in a single Transmit
// if a platform supports Generic Send Offload (GSO).
// Set to 1 for now since not all platforms support GSO.
// TODO: Fix for platforms that support GSO.
let max_datagrams = 1;
// Poll the connection for packets to send on the UDP socket and try to send them on
// `to_endpoint`.
if let Some(transmit) = inner
.connection
.poll_transmit(Instant::now(), max_datagrams)
{
// TODO: ECN bits not handled
self.pending_to_endpoint = Some(ToEndpoint::SendUdpPacket(transmit));
continue;
}
match inner.connection.poll_timeout() {
Some(timeout) => match self.next_timeout {
Some((_, when)) if when == timeout => {}
_ => {
let now = Instant::now();
// 0ns if now > when
let duration = timeout.duration_since(now);
let next_timeout = Delay::new(duration);
self.next_timeout = Some((next_timeout, timeout))
}
},
None => self.next_timeout = None,
}
if let Some((timeout, when)) = self.next_timeout.as_mut() {
if timeout.poll_unpin(cx).is_ready() {
inner.connection.handle_timeout(*when);
continue;
}
}
// The connection also needs to be able to send control messages to the endpoint. This is
// handled here, and we try to send them on `to_endpoint` as well.
if let Some(event) = inner.connection.poll_endpoint_events() {
let connection_id = self.connection_id;
self.pending_to_endpoint = Some(ToEndpoint::ProcessConnectionEvent {
connection_id,
event,
});
continue;
}
// The final step consists in returning the events related to the various substreams.
if let Some(ev) = inner.connection.poll() {
return Poll::Ready(Some(ev));
}
return Poll::Pending;
connection,
incoming: None,
outgoing: None,
closing: None,
}
}
}
impl StreamMuxer for Connection {
type Substream = Substream;
type Substream = Stream;
type Error = Error;
fn poll(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<Result<StreamMuxerEvent, Self::Error>> {
while let Poll::Ready(event) = self.poll_event(cx) {
let mut inner = self.state.lock();
let event = match event {
Some(event) => event,
None => return Poll::Ready(Err(Error::EndpointDriverCrashed)),
};
match event {
quinn_proto::Event::Connected | quinn_proto::Event::HandshakeDataReady => {
debug_assert!(
false,
"Unexpected event {event:?} on established QUIC connection"
);
}
quinn_proto::Event::ConnectionLost { reason } => {
inner
.connection
.close(Instant::now(), From::from(0u32), Default::default());
inner.substreams.values_mut().for_each(|s| s.wake_all());
return Poll::Ready(Err(Error::Connection(reason.into())));
}
quinn_proto::Event::Stream(quinn_proto::StreamEvent::Opened {
dir: quinn_proto::Dir::Bi,
}) => {
if let Some(waker) = inner.poll_outbound_waker.take() {
waker.wake();
}
}
quinn_proto::Event::Stream(quinn_proto::StreamEvent::Available {
dir: quinn_proto::Dir::Bi,
}) => {
if let Some(waker) = inner.poll_inbound_waker.take() {
waker.wake();
}
}
quinn_proto::Event::Stream(quinn_proto::StreamEvent::Readable { id }) => {
if let Some(substream) = inner.substreams.get_mut(&id) {
if let Some(waker) = substream.read_waker.take() {
waker.wake();
}
}
}
quinn_proto::Event::Stream(quinn_proto::StreamEvent::Writable { id }) => {
if let Some(substream) = inner.substreams.get_mut(&id) {
if let Some(waker) = substream.write_waker.take() {
waker.wake();
}
}
}
quinn_proto::Event::Stream(quinn_proto::StreamEvent::Finished { id }) => {
if let Some(substream) = inner.substreams.get_mut(&id) {
if matches!(
substream.write_state,
WriteState::Open | WriteState::Closing
) {
substream.write_state = WriteState::Closed;
}
if let Some(waker) = substream.write_waker.take() {
waker.wake();
}
if let Some(waker) = substream.close_waker.take() {
waker.wake();
}
}
}
quinn_proto::Event::Stream(quinn_proto::StreamEvent::Stopped {
id,
error_code: _,
}) => {
if let Some(substream) = inner.substreams.get_mut(&id) {
substream.write_state = WriteState::Stopped;
if let Some(waker) = substream.write_waker.take() {
waker.wake();
}
if let Some(waker) = substream.close_waker.take() {
waker.wake();
}
}
}
quinn_proto::Event::DatagramReceived
| quinn_proto::Event::Stream(quinn_proto::StreamEvent::Available {
dir: quinn_proto::Dir::Uni,
})
| quinn_proto::Event::Stream(quinn_proto::StreamEvent::Opened {
dir: quinn_proto::Dir::Uni,
}) => {
unreachable!("We don't use datagrams or unidirectional streams.")
}
}
}
// TODO: If connection migration is enabled (currently disabled) address
// change on the connection needs to be handled.
self.state.lock().poll_connection_waker = Some(cx.waker().clone());
Poll::Pending
}
fn poll_inbound(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<Result<Self::Substream, Self::Error>> {
let mut inner = self.state.lock();
let this = self.get_mut();
let substream_id = match inner.connection.streams().accept(quinn_proto::Dir::Bi) {
Some(id) => {
inner.poll_inbound_waker = None;
id
}
None => {
inner.poll_inbound_waker = Some(cx.waker().clone());
return Poll::Pending;
}
};
inner.substreams.insert(substream_id, Default::default());
let substream = Substream::new(substream_id, self.state.clone());
let incoming = this.incoming.get_or_insert_with(|| {
let connection = this.connection.clone();
async move { connection.accept_bi().await }.boxed()
});
Poll::Ready(Ok(substream))
let (send, recv) = futures::ready!(incoming.poll_unpin(cx)).map_err(ConnectionError)?;
this.incoming.take();
let stream = Stream::new(send, recv);
Poll::Ready(Ok(stream))
}
fn poll_outbound(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<Result<Self::Substream, Self::Error>> {
let mut inner = self.state.lock();
let substream_id = match inner.connection.streams().open(quinn_proto::Dir::Bi) {
Some(id) => {
inner.poll_outbound_waker = None;
id
}
None => {
inner.poll_outbound_waker = Some(cx.waker().clone());
return Poll::Pending;
}
let this = self.get_mut();
let outgoing = this.outgoing.get_or_insert_with(|| {
let connection = this.connection.clone();
async move { connection.open_bi().await }.boxed()
});
let (send, recv) = futures::ready!(outgoing.poll_unpin(cx)).map_err(ConnectionError)?;
this.outgoing.take();
let stream = Stream::new(send, recv);
Poll::Ready(Ok(stream))
}
fn poll(
self: Pin<&mut Self>,
_cx: &mut Context<'_>,
) -> Poll<Result<StreamMuxerEvent, Self::Error>> {
// TODO: If connection migration is enabled (currently disabled) address
// change on the connection needs to be handled.
Poll::Pending
}
fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
let this = self.get_mut();
let closing = this.closing.get_or_insert_with(|| {
this.connection.close(From::from(0u32), &[]);
let connection = this.connection.clone();
async move { connection.closed().await }.boxed()
});
match futures::ready!(closing.poll_unpin(cx)) {
// Expected error given that `connection.close` was called above.
quinn::ConnectionError::LocallyClosed => {}
error => return Poll::Ready(Err(Error::Connection(ConnectionError(error)))),
};
inner.substreams.insert(substream_id, Default::default());
let substream = Substream::new(substream_id, self.state.clone());
Poll::Ready(Ok(substream))
}
fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
let mut inner = self.state.lock();
if inner.connection.is_drained() {
return Poll::Ready(Ok(()));
}
for substream in inner.substreams.keys().cloned().collect::<Vec<_>>() {
let _ = inner.connection.send_stream(substream).finish();
}
if inner.connection.streams().send_streams() == 0 && !inner.connection.is_closed() {
inner
.connection
.close(Instant::now(), From::from(0u32), Default::default())
}
drop(inner);
loop {
match ready!(self.poll_event(cx)) {
Some(quinn_proto::Event::ConnectionLost { .. }) => return Poll::Ready(Ok(())),
None => return Poll::Ready(Err(Error::EndpointDriverCrashed)),
_ => {}
}
}
}
}
impl Drop for Connection {
fn drop(&mut self) {
let to_endpoint = ToEndpoint::ProcessConnectionEvent {
connection_id: self.connection_id,
event: quinn_proto::EndpointEvent::drained(),
};
self.endpoint_channel.send_on_drop(to_endpoint);
}
}
/// Mutex-protected state of [`Connection`].
#[derive(Debug)]
pub struct State {
/// The QUIC inner state machine for this specific connection.
connection: quinn_proto::Connection,
/// State of all the substreams that the muxer reports as open.
pub substreams: HashMap<quinn_proto::StreamId, SubstreamState>,
/// Waker to wake if a new outbound substream is opened.
pub poll_outbound_waker: Option<Waker>,
/// Waker to wake if a new inbound substream was happened.
pub poll_inbound_waker: Option<Waker>,
/// Waker to wake if the connection should be polled again.
pub poll_connection_waker: Option<Waker>,
}
impl State {
fn unchecked_substream_state(&mut self, id: quinn_proto::StreamId) -> &mut SubstreamState {
self.substreams
.get_mut(&id)
.expect("Substream should be known.")
Poll::Ready(Ok(()))
}
}

View File

@ -20,9 +20,12 @@
//! Future that drives a QUIC connection until is has performed its TLS handshake.
use crate::{Connection, Error};
use crate::{Connection, ConnectionError, Error};
use futures::prelude::*;
use futures::{
future::{select, Either, FutureExt, Select},
prelude::*,
};
use futures_timer::Delay;
use libp2p_identity::PeerId;
use std::{
@ -34,64 +37,46 @@ use std::{
/// A QUIC connection currently being negotiated.
#[derive(Debug)]
pub struct Connecting {
connection: Option<Connection>,
timeout: Delay,
connecting: Select<quinn::Connecting, Delay>,
}
impl Connecting {
pub(crate) fn new(connection: Connection, timeout: Duration) -> Self {
pub(crate) fn new(connection: quinn::Connecting, timeout: Duration) -> Self {
Connecting {
connection: Some(connection),
timeout: Delay::new(timeout),
connecting: select(connection, Delay::new(timeout)),
}
}
}
impl Connecting {
/// Returns the address of the node we're connected to.
/// Panics if the connection is still handshaking.
fn remote_peer_id(connection: &quinn::Connection) -> PeerId {
let identity = connection
.peer_identity()
.expect("connection got identity because it passed TLS handshake; qed");
let certificates: Box<Vec<rustls::Certificate>> =
identity.downcast().expect("we rely on rustls feature; qed");
let end_entity = certificates
.get(0)
.expect("there should be exactly one certificate; qed");
let p2p_cert = libp2p_tls::certificate::parse(end_entity)
.expect("the certificate was validated during TLS handshake; qed");
p2p_cert.peer_id()
}
}
impl Future for Connecting {
type Output = Result<(PeerId, Connection), Error>;
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let connection = self
.connection
.as_mut()
.expect("Future polled after it has completed");
let connection = match futures::ready!(self.connecting.poll_unpin(cx)) {
Either::Right(_) => return Poll::Ready(Err(Error::HandshakeTimedOut)),
Either::Left((connection, _)) => connection.map_err(ConnectionError)?,
};
loop {
let event = match connection.poll_event(cx) {
Poll::Ready(Some(event)) => event,
Poll::Ready(None) => return Poll::Ready(Err(Error::EndpointDriverCrashed)),
Poll::Pending => {
return self
.timeout
.poll_unpin(cx)
.map(|()| Err(Error::HandshakeTimedOut));
}
};
match event {
quinn_proto::Event::Connected => {
// Parse the remote's Id identity from the certificate.
let identity = connection
.peer_identity()
.expect("connection got identity because it passed TLS handshake; qed");
let certificates: Box<Vec<rustls::Certificate>> =
identity.downcast().expect("we rely on rustls feature; qed");
let end_entity = certificates
.get(0)
.expect("there should be exactly one certificate; qed");
let p2p_cert = libp2p_tls::certificate::parse(end_entity)
.expect("the certificate was validated during TLS handshake; qed");
let peer_id = p2p_cert.peer_id();
return Poll::Ready(Ok((peer_id, self.connection.take().unwrap())));
}
quinn_proto::Event::ConnectionLost { reason } => {
return Poll::Ready(Err(Error::Connection(reason.into())))
}
quinn_proto::Event::HandshakeDataReady | quinn_proto::Event::Stream(_) => {}
quinn_proto::Event::DatagramReceived => {
debug_assert!(false, "Datagrams are not supported")
}
}
}
let peer_id = Self::remote_peer_id(&connection);
let muxer = Connection::new(connection);
Poll::Ready(Ok((peer_id, muxer)))
}
}

View File

@ -0,0 +1,86 @@
// Copyright 2022 Protocol Labs.
//
// Permission is hereby granted, free of charge, to any person obtaining a
// copy of this software and associated documentation files (the "Software"),
// to deal in the Software without restriction, including without limitation
// the rights to use, copy, modify, merge, publish, distribute, sublicense,
// and/or sell copies of the Software, and to permit persons to whom the
// Software is furnished to do so, subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in
// all copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
// OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
// FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
// DEALINGS IN THE SOFTWARE.
use std::{
io::{self},
pin::Pin,
task::{Context, Poll},
};
use futures::{AsyncRead, AsyncWrite};
/// A single stream on a connection
pub struct Stream {
/// A send part of the stream
send: quinn::SendStream,
/// A receive part of the stream
recv: quinn::RecvStream,
/// Whether the stream is closed or not
close_result: Option<Result<(), io::ErrorKind>>,
}
impl Stream {
pub(super) fn new(send: quinn::SendStream, recv: quinn::RecvStream) -> Self {
Self {
send,
recv,
close_result: None,
}
}
}
impl AsyncRead for Stream {
fn poll_read(
mut self: Pin<&mut Self>,
cx: &mut Context,
buf: &mut [u8],
) -> Poll<io::Result<usize>> {
if let Some(close_result) = self.close_result {
if close_result.is_err() {
return Poll::Ready(Ok(0));
}
}
Pin::new(&mut self.recv).poll_read(cx, buf)
}
}
impl AsyncWrite for Stream {
fn poll_write(
mut self: Pin<&mut Self>,
cx: &mut Context,
buf: &[u8],
) -> Poll<io::Result<usize>> {
Pin::new(&mut self.send).poll_write(cx, buf)
}
fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<io::Result<()>> {
Pin::new(&mut self.send).poll_flush(cx)
}
fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<io::Result<()>> {
if let Some(close_result) = self.close_result {
// For some reason poll_close needs to be 'fuse'able
return Poll::Ready(close_result.map_err(Into::into));
}
let close_result = futures::ready!(Pin::new(&mut self.send).poll_close(cx));
self.close_result = Some(close_result.as_ref().map_err(|e| e.kind()).copied());
Poll::Ready(close_result)
}
}

View File

@ -1,257 +0,0 @@
// Copyright 2022 Protocol Labs.
//
// Permission is hereby granted, free of charge, to any person obtaining a
// copy of this software and associated documentation files (the "Software"),
// to deal in the Software without restriction, including without limitation
// the rights to use, copy, modify, merge, publish, distribute, sublicense,
// and/or sell copies of the Software, and to permit persons to whom the
// Software is furnished to do so, subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in
// all copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
// OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
// FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
// DEALINGS IN THE SOFTWARE.
use std::{
io::{self, Write},
pin::Pin,
sync::Arc,
task::{Context, Poll, Waker},
};
use futures::{AsyncRead, AsyncWrite};
use parking_lot::Mutex;
use super::State;
/// Wakers for the [`AsyncRead`] and [`AsyncWrite`] on a substream.
#[derive(Debug, Default, Clone)]
pub struct SubstreamState {
/// Waker to wake if the substream becomes readable.
pub read_waker: Option<Waker>,
/// Waker to wake if the substream becomes writable, closed or stopped.
pub write_waker: Option<Waker>,
/// Waker to wake if the substream becomes closed or stopped.
pub close_waker: Option<Waker>,
pub write_state: WriteState,
}
impl SubstreamState {
/// Wake all wakers for reading, writing and closed the stream.
pub fn wake_all(&mut self) {
if let Some(waker) = self.read_waker.take() {
waker.wake();
}
if let Some(waker) = self.write_waker.take() {
waker.wake();
}
if let Some(waker) = self.close_waker.take() {
waker.wake();
}
}
}
/// A single stream on a connection
#[derive(Debug)]
pub struct Substream {
/// The id of the stream.
id: quinn_proto::StreamId,
/// The state of the [`super::Connection`] this stream belongs to.
state: Arc<Mutex<State>>,
}
impl Substream {
pub fn new(id: quinn_proto::StreamId, state: Arc<Mutex<State>>) -> Self {
Self { id, state }
}
}
impl AsyncRead for Substream {
fn poll_read(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
mut buf: &mut [u8],
) -> Poll<io::Result<usize>> {
let mut state = self.state.lock();
let mut stream = state.connection.recv_stream(self.id);
let mut chunks = match stream.read(true) {
Ok(chunks) => chunks,
Err(quinn_proto::ReadableError::UnknownStream) => {
return Poll::Ready(Ok(0));
}
Err(quinn_proto::ReadableError::IllegalOrderedRead) => {
unreachable!(
"Illegal ordered read can only happen if `stream.read(false)` is used."
);
}
};
let mut bytes = 0;
let mut pending = false;
let mut error = None;
loop {
if buf.is_empty() {
// Chunks::next will continue returning `Ok(Some(_))` with an
// empty chunk if there is no space left in the buffer, so we
// break early here.
break;
}
let chunk = match chunks.next(buf.len()) {
Ok(Some(chunk)) => chunk,
Ok(None) => break,
Err(err @ quinn_proto::ReadError::Reset(_)) => {
error = Some(Err(io::Error::new(io::ErrorKind::ConnectionReset, err)));
break;
}
Err(quinn_proto::ReadError::Blocked) => {
pending = true;
break;
}
};
buf.write_all(&chunk.bytes).expect("enough buffer space");
bytes += chunk.bytes.len();
}
if chunks.finalize().should_transmit() {
if let Some(waker) = state.poll_connection_waker.take() {
waker.wake();
}
}
if let Some(err) = error {
return Poll::Ready(err);
}
if pending && bytes == 0 {
let substream_state = state.unchecked_substream_state(self.id);
substream_state.read_waker = Some(cx.waker().clone());
return Poll::Pending;
}
Poll::Ready(Ok(bytes))
}
}
impl AsyncWrite for Substream {
fn poll_write(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<std::io::Result<usize>> {
let mut state = self.state.lock();
match state.connection.send_stream(self.id).write(buf) {
Ok(bytes) => {
if let Some(waker) = state.poll_connection_waker.take() {
waker.wake();
}
Poll::Ready(Ok(bytes))
}
Err(quinn_proto::WriteError::Blocked) => {
let substream_state = state.unchecked_substream_state(self.id);
substream_state.write_waker = Some(cx.waker().clone());
Poll::Pending
}
Err(err @ quinn_proto::WriteError::Stopped(_)) => {
Poll::Ready(Err(io::Error::new(io::ErrorKind::ConnectionReset, err)))
}
Err(quinn_proto::WriteError::UnknownStream) => {
Poll::Ready(Err(io::ErrorKind::BrokenPipe.into()))
}
}
}
fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
// quinn doesn't support flushing, calling close will flush all substreams.
Poll::Ready(Ok(()))
}
fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
let mut inner = self.state.lock();
let substream_state = inner.unchecked_substream_state(self.id);
match substream_state.write_state {
WriteState::Open => {}
WriteState::Closing => {
substream_state.close_waker = Some(cx.waker().clone());
return Poll::Pending;
}
WriteState::Closed => return Poll::Ready(Ok(())),
WriteState::Stopped => {
let err = quinn_proto::FinishError::Stopped(0u32.into());
return Poll::Ready(Err(io::Error::new(io::ErrorKind::ConnectionReset, err)));
}
}
match inner.connection.send_stream(self.id).finish() {
Ok(()) => {
let substream_state = inner.unchecked_substream_state(self.id);
substream_state.close_waker = Some(cx.waker().clone());
substream_state.write_state = WriteState::Closing;
Poll::Pending
}
Err(err @ quinn_proto::FinishError::Stopped(_)) => {
Poll::Ready(Err(io::Error::new(io::ErrorKind::ConnectionReset, err)))
}
Err(quinn_proto::FinishError::UnknownStream) => {
// We never make up IDs so the stream must have existed at some point if we get to here.
// `UnknownStream` is also emitted in case the stream is already finished, hence just
// return `Ok(())` here.
Poll::Ready(Ok(()))
}
}
}
}
impl Drop for Substream {
fn drop(&mut self) {
let mut state = self.state.lock();
state.substreams.remove(&self.id);
// Send `STOP_STREAM` if the remote did not finish the stream yet.
// We have to manually check the read stream since we might have
// received a `FIN` (without any other stream data) after the last
// time we tried to read.
let mut is_read_done = false;
if let Ok(mut chunks) = state.connection.recv_stream(self.id).read(true) {
if let Ok(chunk) = chunks.next(0) {
is_read_done = chunk.is_none();
}
let _ = chunks.finalize();
}
if !is_read_done {
let _ = state.connection.recv_stream(self.id).stop(0u32.into());
}
// Close the writing side.
let mut send_stream = state.connection.send_stream(self.id);
match send_stream.finish() {
Ok(()) => {}
// Already finished or reset, which is fine.
Err(quinn_proto::FinishError::UnknownStream) => {}
Err(quinn_proto::FinishError::Stopped(reason)) => {
let _ = send_stream.reset(reason);
}
}
}
}
#[derive(Debug, Default, Clone)]
pub enum WriteState {
/// The stream is open for writing.
#[default]
Open,
/// The writing side of the stream is closing.
Closing,
/// All data was successfully sent to the remote and the stream closed,
/// i.e. a [`quinn_proto::StreamEvent::Finished`] was reported for it.
Closed,
/// The stream was stopped by the remote before all data could be
/// sent.
Stopped,
}

View File

@ -1,674 +0,0 @@
// Copyright 2017-2020 Parity Technologies (UK) Ltd.
//
// Permission is hereby granted, free of charge, to any person obtaining a
// copy of this software and associated documentation files (the "Software"),
// to deal in the Software without restriction, including without limitation
// the rights to use, copy, modify, merge, publish, distribute, sublicense,
// and/or sell copies of the Software, and to permit persons to whom the
// Software is furnished to do so, subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in
// all copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
// OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
// FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
// DEALINGS IN THE SOFTWARE.
use crate::{
provider::Provider,
transport::{ProtocolVersion, SocketFamily},
ConnectError, Connection, Error,
};
use bytes::BytesMut;
use futures::{
channel::{mpsc, oneshot},
prelude::*,
};
use quinn_proto::VarInt;
use std::{
collections::HashMap,
net::{Ipv4Addr, Ipv6Addr, SocketAddr},
ops::ControlFlow,
pin::Pin,
sync::Arc,
task::{Context, Poll},
time::{Duration, Instant},
};
// The `Driver` drops packets if the channel to the connection
// or transport is full.
// Set capacity 10 to avoid unnecessary packet drops if the receiver
// is only very briefly busy, but not buffer a large amount of packets
// if it is blocked longer.
const CHANNEL_CAPACITY: usize = 10;
/// Config for the transport.
#[derive(Clone)]
pub struct Config {
/// Timeout for the initial handshake when establishing a connection.
/// The actual timeout is the minimum of this an the [`Config::max_idle_timeout`].
pub handshake_timeout: Duration,
/// Maximum duration of inactivity in ms to accept before timing out the connection.
pub max_idle_timeout: u32,
/// Period of inactivity before sending a keep-alive packet.
/// Must be set lower than the idle_timeout of both
/// peers to be effective.
///
/// See [`quinn_proto::TransportConfig::keep_alive_interval`] for more
/// info.
pub keep_alive_interval: Duration,
/// Maximum number of incoming bidirectional streams that may be open
/// concurrently by the remote peer.
pub max_concurrent_stream_limit: u32,
/// Max unacknowledged data in bytes that may be send on a single stream.
pub max_stream_data: u32,
/// Max unacknowledged data in bytes that may be send in total on all streams
/// of a connection.
pub max_connection_data: u32,
/// Support QUIC version draft-29 for dialing and listening.
///
/// Per default only QUIC Version 1 / [`libp2p_core::multiaddr::Protocol::QuicV1`]
/// is supported.
///
/// If support for draft-29 is enabled servers support draft-29 and version 1 on all
/// QUIC listening addresses.
/// As client the version is chosen based on the remote's address.
pub support_draft_29: bool,
/// TLS client config for the inner [`quinn_proto::ClientConfig`].
client_tls_config: Arc<rustls::ClientConfig>,
/// TLS server config for the inner [`quinn_proto::ServerConfig`].
server_tls_config: Arc<rustls::ServerConfig>,
}
impl Config {
/// Creates a new configuration object with default values.
pub fn new(keypair: &libp2p_identity::Keypair) -> Self {
let client_tls_config = Arc::new(libp2p_tls::make_client_config(keypair, None).unwrap());
let server_tls_config = Arc::new(libp2p_tls::make_server_config(keypair).unwrap());
Self {
client_tls_config,
server_tls_config,
support_draft_29: false,
handshake_timeout: Duration::from_secs(5),
max_idle_timeout: 30 * 1000,
max_concurrent_stream_limit: 256,
keep_alive_interval: Duration::from_secs(15),
max_connection_data: 15_000_000,
// Ensure that one stream is not consuming the whole connection.
max_stream_data: 10_000_000,
}
}
}
/// Represents the inner configuration for [`quinn_proto`].
#[derive(Debug, Clone)]
pub(crate) struct QuinnConfig {
client_config: quinn_proto::ClientConfig,
server_config: Arc<quinn_proto::ServerConfig>,
endpoint_config: Arc<quinn_proto::EndpointConfig>,
}
impl From<Config> for QuinnConfig {
fn from(config: Config) -> QuinnConfig {
let Config {
client_tls_config,
server_tls_config,
max_idle_timeout,
max_concurrent_stream_limit,
keep_alive_interval,
max_connection_data,
max_stream_data,
support_draft_29,
handshake_timeout: _,
} = config;
let mut transport = quinn_proto::TransportConfig::default();
// Disable uni-directional streams.
transport.max_concurrent_uni_streams(0u32.into());
transport.max_concurrent_bidi_streams(max_concurrent_stream_limit.into());
// Disable datagrams.
transport.datagram_receive_buffer_size(None);
transport.keep_alive_interval(Some(keep_alive_interval));
transport.max_idle_timeout(Some(VarInt::from_u32(max_idle_timeout).into()));
transport.allow_spin(false);
transport.stream_receive_window(max_stream_data.into());
transport.receive_window(max_connection_data.into());
let transport = Arc::new(transport);
let mut server_config = quinn_proto::ServerConfig::with_crypto(server_tls_config);
server_config.transport = Arc::clone(&transport);
// Disables connection migration.
// Long-term this should be enabled, however we then need to handle address change
// on connections in the `Connection`.
server_config.migration(false);
let mut client_config = quinn_proto::ClientConfig::new(client_tls_config);
client_config.transport_config(transport);
let mut endpoint_config = quinn_proto::EndpointConfig::default();
if !support_draft_29 {
endpoint_config.supported_versions(vec![1]);
}
QuinnConfig {
client_config,
server_config: Arc::new(server_config),
endpoint_config: Arc::new(endpoint_config),
}
}
}
/// Channel used to send commands to the [`Driver`].
#[derive(Debug, Clone)]
pub(crate) struct Channel {
/// Channel to the background of the endpoint.
to_endpoint: mpsc::Sender<ToEndpoint>,
/// Address that the socket is bound to.
/// Note: this may be a wildcard ip address.
socket_addr: SocketAddr,
}
impl Channel {
/// Builds a new endpoint that is listening on the [`SocketAddr`].
pub(crate) fn new_bidirectional<P: Provider>(
quinn_config: QuinnConfig,
socket_addr: SocketAddr,
) -> Result<(Self, mpsc::Receiver<Connection>), Error> {
// Channel for forwarding new inbound connections to the listener.
let (new_connections_tx, new_connections_rx) = mpsc::channel(CHANNEL_CAPACITY);
let endpoint = Self::new::<P>(quinn_config, socket_addr, Some(new_connections_tx))?;
Ok((endpoint, new_connections_rx))
}
/// Builds a new endpoint that only supports outbound connections.
pub(crate) fn new_dialer<P: Provider>(
quinn_config: QuinnConfig,
socket_family: SocketFamily,
) -> Result<Self, Error> {
let socket_addr = match socket_family {
SocketFamily::Ipv4 => SocketAddr::new(Ipv4Addr::UNSPECIFIED.into(), 0),
SocketFamily::Ipv6 => SocketAddr::new(Ipv6Addr::UNSPECIFIED.into(), 0),
};
Self::new::<P>(quinn_config, socket_addr, None)
}
/// Spawn a new [`Driver`] that runs in the background.
fn new<P: Provider>(
quinn_config: QuinnConfig,
socket_addr: SocketAddr,
new_connections: Option<mpsc::Sender<Connection>>,
) -> Result<Self, Error> {
let socket = std::net::UdpSocket::bind(socket_addr)?;
// NOT blocking, as per man:bind(2), as we pass an IP address.
socket.set_nonblocking(true)?;
// Capacity 0 to back-pressure the rest of the application if
// the udp socket is busy.
let (to_endpoint_tx, to_endpoint_rx) = mpsc::channel(0);
let channel = Self {
to_endpoint: to_endpoint_tx,
socket_addr: socket.local_addr()?,
};
let server_config = new_connections
.is_some()
.then_some(quinn_config.server_config);
let provider_socket = P::from_socket(socket)?;
let driver = Driver::<P>::new(
quinn_config.endpoint_config,
quinn_config.client_config,
new_connections,
server_config,
channel.clone(),
provider_socket,
to_endpoint_rx,
);
// Drive the endpoint future in the background.
P::spawn(driver);
Ok(channel)
}
pub(crate) fn socket_addr(&self) -> &SocketAddr {
&self.socket_addr
}
/// Try to send a message to the background task without blocking.
///
/// This first polls the channel for capacity.
/// If the channel is full, the message is returned in `Ok(Err(_))`
/// and the context's waker is registered for wake-up.
///
/// If the background task crashed `Err` is returned.
pub(crate) fn try_send(
&mut self,
to_endpoint: ToEndpoint,
cx: &mut Context<'_>,
) -> Result<Result<(), ToEndpoint>, Disconnected> {
match self.to_endpoint.poll_ready_unpin(cx) {
Poll::Ready(Ok(())) => {}
Poll::Ready(Err(e)) => {
debug_assert!(
e.is_disconnected(),
"mpsc::Sender can only be disconnected when calling `poll_ready_unpin"
);
return Err(Disconnected {});
}
Poll::Pending => return Ok(Err(to_endpoint)),
};
if let Err(e) = self.to_endpoint.start_send(to_endpoint) {
debug_assert!(e.is_disconnected(), "We called `Sink::poll_ready` so we are guaranteed to have a slot. If this fails, it means we are disconnected.");
return Err(Disconnected {});
}
Ok(Ok(()))
}
pub(crate) async fn send(&mut self, to_endpoint: ToEndpoint) -> Result<(), Disconnected> {
self.to_endpoint
.send(to_endpoint)
.await
.map_err(|_| Disconnected {})
}
/// Send a message to inform the [`Driver`] about an
/// event caused by the owner of this [`Channel`] dropping.
/// This clones the sender to the endpoint to guarantee delivery.
/// This should *not* be called for regular messages.
pub(crate) fn send_on_drop(&mut self, to_endpoint: ToEndpoint) {
let _ = self.to_endpoint.clone().try_send(to_endpoint);
}
}
#[derive(Debug, thiserror::Error, Clone, PartialEq, Eq)]
#[error("Background task disconnected")]
pub(crate) struct Disconnected {}
/// Message sent to the endpoint background task.
#[derive(Debug)]
pub(crate) enum ToEndpoint {
/// Instruct the [`quinn_proto::Endpoint`] to start connecting to the given address.
Dial {
/// UDP address to connect to.
addr: SocketAddr,
/// Version to dial the remote on.
version: ProtocolVersion,
/// Channel to return the result of the dialing to.
result: oneshot::Sender<Result<Connection, Error>>,
},
/// Send by a [`quinn_proto::Connection`] when the endpoint needs to process an event generated
/// by a connection. The event itself is opaque to us. Only `quinn_proto` knows what is in
/// there.
ProcessConnectionEvent {
connection_id: quinn_proto::ConnectionHandle,
event: quinn_proto::EndpointEvent,
},
/// Instruct the endpoint to send a packet of data on its UDP socket.
SendUdpPacket(quinn_proto::Transmit),
/// The [`GenTransport`][crate::GenTransport] dialer or listener coupled to this endpoint
/// was dropped.
/// Once all pending connections are closed, the [`Driver`] should shut down.
Decoupled,
}
/// Driver that runs in the background for as long as the endpoint is alive. Responsible for
/// processing messages and the UDP socket.
///
/// # Behaviour
///
/// This background task is responsible for the following:
///
/// - Sending packets on the UDP socket.
/// - Receiving packets from the UDP socket and feed them to the [`quinn_proto::Endpoint`] state
/// machine.
/// - Transmitting events generated by the [`quinn_proto::Endpoint`] to the corresponding
/// [`crate::Connection`].
/// - Receiving messages from the `rx` and processing the requested actions. This includes
/// UDP packets to send and events emitted by the [`crate::Connection`] objects.
/// - Sending new connections on `new_connection_tx`.
///
/// When it comes to channels, there exists three main multi-producer-single-consumer channels
/// in play:
///
/// - One channel, represented by `EndpointChannel::to_endpoint` and `Driver::rx`,
/// that communicates messages from [`Channel`] to the [`Driver`].
/// - One channel for each existing connection that communicates messages from the
/// [`Driver` to that [`crate::Connection`].
/// - One channel for the [`Driver`] to send newly-opened connections to. The receiving
/// side is processed by the [`GenTransport`][crate::GenTransport].
///
///
/// ## Back-pressure
///
/// ### If writing to the UDP socket is blocked
///
/// In order to avoid an unbounded buffering of events, we prioritize sending data on the UDP
/// socket over everything else. Messages from the rest of the application sent through the
/// [`Channel`] are only processed if the UDP socket is ready so that we propagate back-pressure
/// in case of a busy socket. For connections, thus this eventually also back-pressures the
/// `AsyncWrite`on substreams.
///
///
/// ### Back-pressuring the remote if the application is busy
///
/// If the channel to a connection is full because the connection is busy, inbound datagrams
/// for that connection are dropped so that the remote is backpressured.
/// The same applies for new connections if the transport is too busy to received it.
///
///
/// # Shutdown
///
/// The background task shuts down if an [`ToEndpoint::Decoupled`] event was received and the
/// last active connection has drained.
#[derive(Debug)]
pub(crate) struct Driver<P: Provider> {
// The actual QUIC state machine.
endpoint: quinn_proto::Endpoint,
// QuinnConfig for client connections.
client_config: quinn_proto::ClientConfig,
// Copy of the channel to the endpoint driver that is passed to each new connection.
channel: Channel,
// Channel to receive messages from the transport or connections.
rx: mpsc::Receiver<ToEndpoint>,
// Socket for sending and receiving datagrams.
provider_socket: P,
// Future for writing the next packet to the socket.
next_packet_out: Option<quinn_proto::Transmit>,
// List of all active connections, with a sender to notify them of events.
alive_connections:
HashMap<quinn_proto::ConnectionHandle, mpsc::Sender<quinn_proto::ConnectionEvent>>,
// Channel to forward new inbound connections to the transport.
// `None` if server capabilities are disabled, i.e. the endpoint is only used for dialing.
new_connection_tx: Option<mpsc::Sender<Connection>>,
// Whether the transport dropped its handle for this endpoint.
is_decoupled: bool,
}
impl<P: Provider> Driver<P> {
fn new(
endpoint_config: Arc<quinn_proto::EndpointConfig>,
client_config: quinn_proto::ClientConfig,
new_connection_tx: Option<mpsc::Sender<Connection>>,
server_config: Option<Arc<quinn_proto::ServerConfig>>,
channel: Channel,
socket: P,
rx: mpsc::Receiver<ToEndpoint>,
) -> Self {
Driver {
endpoint: quinn_proto::Endpoint::new(endpoint_config, server_config, false),
client_config,
channel,
rx,
provider_socket: socket,
next_packet_out: None,
alive_connections: HashMap::new(),
new_connection_tx,
is_decoupled: false,
}
}
/// Handle a message sent from either the [`GenTransport`](super::GenTransport)
/// or a [`crate::Connection`].
fn handle_message(
&mut self,
to_endpoint: ToEndpoint,
) -> ControlFlow<(), Option<quinn_proto::Transmit>> {
match to_endpoint {
ToEndpoint::Dial {
addr,
result,
version,
} => {
let mut config = self.client_config.clone();
if version == ProtocolVersion::Draft29 {
config.version(0xff00_001d);
}
// This `"l"` seems necessary because an empty string is an invalid domain
// name. While we don't use domain names, the underlying rustls library
// is based upon the assumption that we do.
let (connection_id, connection) = match self.endpoint.connect(config, addr, "l") {
Ok(c) => c,
Err(err) => {
let _ = result.send(Err(ConnectError::from(err).into()));
return ControlFlow::Continue(None);
}
};
debug_assert_eq!(connection.side(), quinn_proto::Side::Client);
let (tx, rx) = mpsc::channel(CHANNEL_CAPACITY);
let connection = Connection::from_quinn_connection(
self.channel.clone(),
connection,
connection_id,
rx,
);
self.alive_connections.insert(connection_id, tx);
let _ = result.send(Ok(connection));
}
// A connection wants to notify the endpoint of something.
ToEndpoint::ProcessConnectionEvent {
connection_id,
event,
} => {
let has_key = self.alive_connections.contains_key(&connection_id);
if !has_key {
return ControlFlow::Continue(None);
}
// We "drained" event indicates that the connection no longer exists and
// its ID can be reclaimed.
let is_drained_event = event.is_drained();
if is_drained_event {
self.alive_connections.remove(&connection_id);
if self.is_decoupled && self.alive_connections.is_empty() {
log::debug!(
"Driver is decoupled and no active connections remain. Shutting down."
);
return ControlFlow::Break(());
}
}
let event_back = self.endpoint.handle_event(connection_id, event);
if let Some(event_back) = event_back {
debug_assert!(!is_drained_event);
if let Some(sender) = self.alive_connections.get_mut(&connection_id) {
// We clone the sender to guarantee that there will be at least one
// free slot to send the event.
// The channel can not grow out of bound because an `event_back` is
// only sent if we previously received an event from the same connection.
// If the connection is busy, it won't sent us any more events to handle.
let _ = sender.clone().start_send(event_back);
} else {
log::error!("State mismatch: event for closed connection");
}
}
}
// Data needs to be sent on the UDP socket.
ToEndpoint::SendUdpPacket(transmit) => return ControlFlow::Continue(Some(transmit)),
ToEndpoint::Decoupled => self.handle_decoupling()?,
}
ControlFlow::Continue(None)
}
/// Handle an UDP datagram received on the socket.
/// The datagram content was written into the `socket_recv_buffer`.
fn handle_datagram(&mut self, packet: BytesMut, packet_src: SocketAddr) -> ControlFlow<()> {
let local_ip = self.channel.socket_addr.ip();
// TODO: ECN bits aren't handled
let (connec_id, event) =
match self
.endpoint
.handle(Instant::now(), packet_src, Some(local_ip), None, packet)
{
Some(event) => event,
None => return ControlFlow::Continue(()),
};
match event {
quinn_proto::DatagramEvent::ConnectionEvent(event) => {
// `event` has type `quinn_proto::ConnectionEvent`, which has multiple
// variants. `quinn_proto::Endpoint::handle` however only ever returns
// `ConnectionEvent::Datagram`.
debug_assert!(format!("{event:?}").contains("Datagram"));
// Redirect the datagram to its connection.
if let Some(sender) = self.alive_connections.get_mut(&connec_id) {
match sender.try_send(event) {
Ok(()) => {}
Err(err) if err.is_disconnected() => {
// Connection was dropped by the user.
// Inform the endpoint that this connection is drained.
self.endpoint
.handle_event(connec_id, quinn_proto::EndpointEvent::drained());
self.alive_connections.remove(&connec_id);
}
Err(err) if err.is_full() => {
// Connection is too busy. Drop the datagram to back-pressure the remote.
log::debug!(
"Dropping packet for connection {:?} because the connection's channel is full.",
connec_id
);
}
Err(_) => unreachable!("Error is either `Full` or `Disconnected`."),
}
} else {
log::error!("State mismatch: event for closed connection");
}
}
quinn_proto::DatagramEvent::NewConnection(connec) => {
// A new connection has been received. `connec_id` is a newly-allocated
// identifier.
debug_assert_eq!(connec.side(), quinn_proto::Side::Server);
let connection_tx = match self.new_connection_tx.as_mut() {
Some(tx) => tx,
None => {
debug_assert!(false, "Endpoint reported a new connection even though server capabilities are disabled.");
return ControlFlow::Continue(());
}
};
let (tx, rx) = mpsc::channel(CHANNEL_CAPACITY);
let connection =
Connection::from_quinn_connection(self.channel.clone(), connec, connec_id, rx);
match connection_tx.start_send(connection) {
Ok(()) => {
self.alive_connections.insert(connec_id, tx);
}
Err(e) if e.is_disconnected() => self.handle_decoupling()?,
Err(e) if e.is_full() => log::warn!(
"Dropping new incoming connection {:?} because the channel to the listener is full",
connec_id
),
Err(_) => unreachable!("Error is either `Full` or `Disconnected`."),
}
}
}
ControlFlow::Continue(())
}
/// The transport dropped the channel to this [`Driver`].
fn handle_decoupling(&mut self) -> ControlFlow<()> {
if self.alive_connections.is_empty() {
return ControlFlow::Break(());
}
// Listener was closed.
self.endpoint.reject_new_connections();
self.new_connection_tx = None;
self.is_decoupled = true;
ControlFlow::Continue(())
}
}
/// Future that runs until the [`Driver`] is decoupled and not active connections
/// remain
impl<P: Provider> Future for Driver<P> {
type Output = ();
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
loop {
// Flush any pending pocket so that the socket is reading to write an next
// packet.
match self.provider_socket.poll_send_flush(cx) {
// The pending packet was send or no packet was pending.
Poll::Ready(Ok(_)) => {
// Start sending a packet on the socket.
if let Some(transmit) = self.next_packet_out.take() {
self.provider_socket
.start_send(transmit.contents.into(), transmit.destination);
continue;
}
// The endpoint might request packets to be sent out. This is handled in
// priority to avoid buffering up packets.
if let Some(transmit) = self.endpoint.poll_transmit() {
self.next_packet_out = Some(transmit);
continue;
}
// Handle messages from transport and connections.
match self.rx.poll_next_unpin(cx) {
Poll::Ready(Some(to_endpoint)) => match self.handle_message(to_endpoint) {
ControlFlow::Continue(Some(transmit)) => {
self.next_packet_out = Some(transmit);
continue;
}
ControlFlow::Continue(None) => continue,
ControlFlow::Break(()) => break,
},
Poll::Ready(None) => {
unreachable!("Sender side is never dropped or closed.")
}
Poll::Pending => {}
}
}
// Errors on the socket are expected to never happen, and we handle them by simply
// printing a log message. The packet gets discarded in case of error, but we are
// robust to packet losses and it is consequently not a logic error to proceed with
// normal operations.
Poll::Ready(Err(err)) => {
log::warn!("Error while sending on QUIC UDP socket: {:?}", err);
continue;
}
Poll::Pending => {}
}
// Poll for new packets from the remote.
match self.provider_socket.poll_recv_from(cx) {
Poll::Ready(Ok((bytes, packet_src))) => {
let bytes_mut = bytes.as_slice().into();
match self.handle_datagram(bytes_mut, packet_src) {
ControlFlow::Continue(()) => continue,
ControlFlow::Break(()) => break,
}
}
// Errors on the socket are expected to never happen, and we handle them by
// simply printing a log message.
Poll::Ready(Err(err)) => {
log::warn!("Error while receive on QUIC UDP socket: {:?}", err);
continue;
}
Poll::Pending => {}
}
return Poll::Pending;
}
Poll::Ready(())
}
}

View File

@ -1,19 +1,20 @@
use std::{net::SocketAddr, time::Duration};
use crate::{provider::Provider, Error};
use futures::future::Either;
use rand::{distributions, Rng};
use crate::{
endpoint::{self, ToEndpoint},
Error, Provider,
use std::{
net::{SocketAddr, UdpSocket},
time::Duration,
};
pub(crate) async fn hole_puncher<P: Provider>(
endpoint_channel: endpoint::Channel,
socket: UdpSocket,
remote_addr: SocketAddr,
timeout_duration: Duration,
) -> Error {
let punch_holes_future = punch_holes::<P>(endpoint_channel, remote_addr);
let punch_holes_future = punch_holes::<P>(socket, remote_addr);
futures::pin_mut!(punch_holes_future);
match futures::future::select(P::sleep(timeout_duration), punch_holes_future).await {
Either::Left(_) => Error::HandshakeTimedOut,
@ -21,27 +22,18 @@ pub(crate) async fn hole_puncher<P: Provider>(
}
}
async fn punch_holes<P: Provider>(
mut endpoint_channel: endpoint::Channel,
remote_addr: SocketAddr,
) -> Error {
async fn punch_holes<P: Provider>(socket: UdpSocket, remote_addr: SocketAddr) -> Error {
loop {
let sleep_duration = Duration::from_millis(rand::thread_rng().gen_range(10..=200));
P::sleep(sleep_duration).await;
let random_udp_packet = ToEndpoint::SendUdpPacket(quinn_proto::Transmit {
destination: remote_addr,
ecn: None,
contents: rand::thread_rng()
.sample_iter(distributions::Standard)
.take(64)
.collect(),
segment_size: None,
src_ip: None,
});
let contents: Vec<u8> = rand::thread_rng()
.sample_iter(distributions::Standard)
.take(64)
.collect();
if endpoint_channel.send(random_udp_packet).await.is_err() {
return Error::EndpointDriverCrashed;
if let Err(e) = P::send_to(&socket, &contents, remote_addr).await {
return Error::Io(e);
}
}
}

View File

@ -57,16 +57,17 @@
#![cfg_attr(docsrs, feature(doc_cfg, doc_auto_cfg))]
mod config;
mod connection;
mod endpoint;
mod hole_punching;
mod provider;
mod transport;
use std::net::SocketAddr;
pub use connection::{Connecting, Connection, Substream};
pub use endpoint::Config;
pub use config::Config;
pub use connection::{Connecting, Connection, Stream};
#[cfg(feature = "async-std")]
pub use provider::async_std;
#[cfg(feature = "tokio")]
@ -89,8 +90,7 @@ pub enum Error {
#[error(transparent)]
Io(#[from] std::io::Error),
/// The task spawned in [`Provider::spawn`] to drive
/// the quic endpoint has crashed.
/// The task to drive a quic endpoint has crashed.
#[error("Endpoint driver crashed")]
EndpointDriverCrashed,
@ -110,9 +110,9 @@ pub enum Error {
/// Dialing a remote peer failed.
#[derive(Debug, thiserror::Error)]
#[error(transparent)]
pub struct ConnectError(#[from] quinn_proto::ConnectError);
pub struct ConnectError(quinn::ConnectError);
/// Error on an established [`Connection`].
#[derive(Debug, thiserror::Error)]
#[error(transparent)]
pub struct ConnectionError(#[from] quinn_proto::ConnectionError);
pub struct ConnectionError(quinn::ConnectionError);

View File

@ -18,11 +18,11 @@
// FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
// DEALINGS IN THE SOFTWARE.
use futures::{future::BoxFuture, Future};
use futures::future::BoxFuture;
use if_watch::IfEvent;
use std::{
io,
net::SocketAddr,
net::{SocketAddr, UdpSocket},
task::{Context, Poll},
time::Duration,
};
@ -32,40 +32,20 @@ pub mod async_std;
#[cfg(feature = "tokio")]
pub mod tokio;
/// Size of the buffer for reading data 0x10000.
#[cfg(any(feature = "async-std", feature = "tokio"))]
const RECEIVE_BUFFER_SIZE: usize = 65536;
pub enum Runtime {
#[cfg(feature = "tokio")]
Tokio,
#[cfg(feature = "async-std")]
AsyncStd,
Dummy,
}
/// Provider for non-blocking receiving and sending on a [`std::net::UdpSocket`]
/// and spawning tasks.
/// Provider for a corresponding quinn runtime and spawning tasks.
pub trait Provider: Unpin + Send + Sized + 'static {
type IfWatcher: Unpin + Send;
/// Create a new providing that is wrapping the socket.
///
/// Note: The socket must be set to non-blocking.
fn from_socket(socket: std::net::UdpSocket) -> io::Result<Self>;
/// Receive a single packet.
///
/// Returns the message and the address the message came from.
fn poll_recv_from(&mut self, cx: &mut Context<'_>) -> Poll<io::Result<(Vec<u8>, SocketAddr)>>;
/// Set sending a packet on the socket.
///
/// Since only one packet can be sent at a time, this may only be called if a preceding
/// call to [`Provider::poll_send_flush`] returned [`Poll::Ready`].
fn start_send(&mut self, data: Vec<u8>, addr: SocketAddr);
/// Flush a packet send in [`Provider::start_send`].
///
/// If [`Poll::Ready`] is returned the socket is ready for sending a new packet.
fn poll_send_flush(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>>;
/// Run the given future in the background until it ends.
///
/// This is used to spawn the task that is driving the endpoint.
fn spawn(future: impl Future<Output = ()> + Send + 'static);
/// Run the corresponding runtime
fn runtime() -> Runtime;
/// Create a new [`if_watch`] watcher that reports [`IfEvent`]s for network interface changes.
fn new_if_watcher() -> io::Result<Self::IfWatcher>;
@ -78,4 +58,11 @@ pub trait Provider: Unpin + Send + Sized + 'static {
/// Sleep for specified amount of time.
fn sleep(duration: Duration) -> BoxFuture<'static, ()>;
/// Sends data on the socket to the given address. On success, returns the number of bytes written.
fn send_to<'a>(
udp_socket: &'a UdpSocket,
buf: &'a [u8],
target: SocketAddr,
) -> BoxFuture<'a, io::Result<usize>>;
}

View File

@ -18,13 +18,10 @@
// FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
// DEALINGS IN THE SOFTWARE.
use async_std::{net::UdpSocket, task::spawn};
use futures::{future::BoxFuture, ready, Future, FutureExt, Stream, StreamExt};
use futures::{future::BoxFuture, FutureExt};
use std::{
io,
net::SocketAddr,
pin::Pin,
sync::Arc,
net::UdpSocket,
task::{Context, Poll},
time::Duration,
};
@ -34,65 +31,14 @@ use crate::GenTransport;
/// Transport with [`async-std`] runtime.
pub type Transport = GenTransport<Provider>;
/// Provider for reading / writing to a sockets and spawning
/// tasks using [`async-std`].
pub struct Provider {
socket: Arc<UdpSocket>,
// Future for sending a packet.
// This is needed since [`async_Std::net::UdpSocket`] does not
// provide a poll-style interface for sending a packet.
send_packet: Option<BoxFuture<'static, Result<(), io::Error>>>,
recv_stream: ReceiveStream,
}
/// Provider for quinn runtime and spawning tasks using [`async-std`].
pub struct Provider;
impl super::Provider for Provider {
type IfWatcher = if_watch::smol::IfWatcher;
fn from_socket(socket: std::net::UdpSocket) -> io::Result<Self> {
let socket = Arc::new(socket.into());
let recv_stream = ReceiveStream::new(Arc::clone(&socket));
Ok(Provider {
socket,
send_packet: None,
recv_stream,
})
}
fn poll_recv_from(&mut self, cx: &mut Context<'_>) -> Poll<io::Result<(Vec<u8>, SocketAddr)>> {
match self.recv_stream.poll_next_unpin(cx) {
Poll::Ready(ready) => {
Poll::Ready(ready.expect("ReceiveStream::poll_next never returns None."))
}
Poll::Pending => Poll::Pending,
}
}
fn start_send(&mut self, data: Vec<u8>, addr: SocketAddr) {
let socket = self.socket.clone();
let send = async move {
socket.send_to(&data, addr).await?;
Ok(())
}
.boxed();
self.send_packet = Some(send)
}
fn poll_send_flush(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
let pending = match self.send_packet.as_mut() {
Some(pending) => pending,
None => return Poll::Ready(Ok(())),
};
match pending.poll_unpin(cx) {
Poll::Ready(result) => {
self.send_packet = None;
Poll::Ready(result)
}
Poll::Pending => Poll::Pending,
}
}
fn spawn(future: impl Future<Output = ()> + Send + 'static) {
spawn(future);
fn runtime() -> super::Runtime {
super::Runtime::AsyncStd
}
fn new_if_watcher() -> io::Result<Self::IfWatcher> {
@ -109,48 +55,16 @@ impl super::Provider for Provider {
fn sleep(duration: Duration) -> BoxFuture<'static, ()> {
async_std::task::sleep(duration).boxed()
}
}
type ReceiveStreamItem = (
Result<(usize, SocketAddr), io::Error>,
Arc<UdpSocket>,
Vec<u8>,
);
/// Wrapper around the socket to implement `Stream` on it.
struct ReceiveStream {
/// Future for receiving a packet on the socket.
// This is needed since [`async_Std::net::UdpSocket`] does not
// provide a poll-style interface for receiving packets.
fut: BoxFuture<'static, ReceiveStreamItem>,
}
impl ReceiveStream {
fn new(socket: Arc<UdpSocket>) -> Self {
let fut = ReceiveStream::next(socket, vec![0; super::RECEIVE_BUFFER_SIZE]).boxed();
Self { fut: fut.boxed() }
}
async fn next(socket: Arc<UdpSocket>, mut socket_recv_buffer: Vec<u8>) -> ReceiveStreamItem {
let recv = socket.recv_from(&mut socket_recv_buffer).await;
(recv, socket, socket_recv_buffer)
}
}
impl Stream for ReceiveStream {
type Item = Result<(Vec<u8>, SocketAddr), io::Error>;
fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
let (result, socket, buffer) = ready!(self.fut.poll_unpin(cx));
let result = result.map(|(packet_len, packet_src)| {
debug_assert!(packet_len <= buffer.len());
// Copies the bytes from the `socket_recv_buffer` they were written into.
(buffer[..packet_len].into(), packet_src)
});
// Set the future for receiving the next packet on the stream.
self.fut = ReceiveStream::next(socket, buffer).boxed();
Poll::Ready(Some(result))
fn send_to<'a>(
udp_socket: &'a UdpSocket,
buf: &'a [u8],
target: std::net::SocketAddr,
) -> BoxFuture<'a, io::Result<usize>> {
Box::pin(async move {
async_std::net::UdpSocket::from(udp_socket.try_clone()?)
.send_to(buf, target)
.await
})
}
}

View File

@ -18,72 +18,27 @@
// FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
// DEALINGS IN THE SOFTWARE.
use futures::{future::BoxFuture, ready, Future, FutureExt};
use futures::{future::BoxFuture, FutureExt};
use std::{
io,
net::SocketAddr,
net::{SocketAddr, UdpSocket},
task::{Context, Poll},
time::Duration,
};
use tokio::{io::ReadBuf, net::UdpSocket};
use crate::GenTransport;
/// Transport with [`tokio`] runtime.
pub type Transport = GenTransport<Provider>;
/// Provider for reading / writing to a sockets and spawning
/// tasks using [`tokio`].
pub struct Provider {
socket: UdpSocket,
socket_recv_buffer: Vec<u8>,
next_packet_out: Option<(Vec<u8>, SocketAddr)>,
}
/// Provider for quinn runtime and spawning tasks using [`tokio`].
pub struct Provider;
impl super::Provider for Provider {
type IfWatcher = if_watch::tokio::IfWatcher;
fn from_socket(socket: std::net::UdpSocket) -> std::io::Result<Self> {
let socket = UdpSocket::from_std(socket)?;
Ok(Provider {
socket,
socket_recv_buffer: vec![0; super::RECEIVE_BUFFER_SIZE],
next_packet_out: None,
})
}
fn poll_send_flush(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
let (data, addr) = match self.next_packet_out.as_ref() {
Some(pending) => pending,
None => return Poll::Ready(Ok(())),
};
match self.socket.poll_send_to(cx, data.as_slice(), *addr) {
Poll::Ready(result) => {
self.next_packet_out = None;
Poll::Ready(result.map(|_| ()))
}
Poll::Pending => Poll::Pending,
}
}
fn poll_recv_from(&mut self, cx: &mut Context<'_>) -> Poll<io::Result<(Vec<u8>, SocketAddr)>> {
let Self {
socket,
socket_recv_buffer,
..
} = self;
let mut read_buf = ReadBuf::new(socket_recv_buffer.as_mut_slice());
let packet_src = ready!(socket.poll_recv_from(cx, &mut read_buf)?);
let bytes = read_buf.filled().to_vec();
Poll::Ready(Ok((bytes, packet_src)))
}
fn start_send(&mut self, data: Vec<u8>, addr: SocketAddr) {
self.next_packet_out = Some((data, addr));
}
fn spawn(future: impl Future<Output = ()> + Send + 'static) {
tokio::spawn(future);
fn runtime() -> super::Runtime {
super::Runtime::Tokio
}
fn new_if_watcher() -> io::Result<Self::IfWatcher> {
@ -100,4 +55,16 @@ impl super::Provider for Provider {
fn sleep(duration: Duration) -> BoxFuture<'static, ()> {
tokio::time::sleep(duration).boxed()
}
fn send_to<'a>(
udp_socket: &'a UdpSocket,
buf: &'a [u8],
target: SocketAddr,
) -> BoxFuture<'a, io::Result<usize>> {
Box::pin(async move {
tokio::net::UdpSocket::from_std(udp_socket.try_clone()?)?
.send_to(buf, target)
.await
})
}
}

View File

@ -18,12 +18,12 @@
// FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
// DEALINGS IN THE SOFTWARE.
use crate::endpoint::{Config, QuinnConfig, ToEndpoint};
use crate::config::{Config, QuinnConfig};
use crate::hole_punching::hole_puncher;
use crate::provider::Provider;
use crate::{endpoint, Connecting, Connection, Error};
use crate::{ConnectError, Connecting, Connection, Error};
use futures::channel::{mpsc, oneshot};
use futures::channel::oneshot;
use futures::future::{BoxFuture, Either};
use futures::ready;
use futures::stream::StreamExt;
@ -38,10 +38,10 @@ use libp2p_core::{
};
use libp2p_identity::PeerId;
use std::collections::hash_map::{DefaultHasher, Entry};
use std::collections::{HashMap, VecDeque};
use std::collections::HashMap;
use std::fmt;
use std::hash::{Hash, Hasher};
use std::net::IpAddr;
use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, UdpSocket};
use std::time::Duration;
use std::{
net::SocketAddr,
@ -62,7 +62,7 @@ use std::{
/// See <https://github.com/multiformats/multiaddr/issues/145>.
#[derive(Debug)]
pub struct GenTransport<P: Provider> {
/// Config for the inner [`quinn_proto`] structs.
/// Config for the inner [`quinn`] structs.
quinn_config: QuinnConfig,
/// Timeout for the [`Connecting`] future.
handshake_timeout: Duration,
@ -71,7 +71,7 @@ pub struct GenTransport<P: Provider> {
/// Streams of active [`Listener`]s.
listeners: SelectAll<Listener<P>>,
/// Dialer for each socket family if no matching listener exists.
dialer: HashMap<SocketFamily, Dialer>,
dialer: HashMap<SocketFamily, quinn::Endpoint>,
/// Waker to poll the transport again when a new dialer or listener is added.
waker: Option<Waker>,
/// Holepunching attempts
@ -95,21 +95,57 @@ impl<P: Provider> GenTransport<P> {
}
}
/// Create a new [`quinn::Endpoint`] with the given configs.
fn new_endpoint(
endpoint_config: quinn::EndpointConfig,
server_config: Option<quinn::ServerConfig>,
socket: UdpSocket,
) -> Result<quinn::Endpoint, Error> {
use crate::provider::Runtime;
match P::runtime() {
#[cfg(feature = "tokio")]
Runtime::Tokio => {
let runtime = std::sync::Arc::new(quinn::TokioRuntime);
let endpoint =
quinn::Endpoint::new(endpoint_config, server_config, socket, runtime)?;
Ok(endpoint)
}
#[cfg(feature = "async-std")]
Runtime::AsyncStd => {
let runtime = std::sync::Arc::new(quinn::AsyncStdRuntime);
let endpoint =
quinn::Endpoint::new(endpoint_config, server_config, socket, runtime)?;
Ok(endpoint)
}
Runtime::Dummy => {
let _ = endpoint_config;
let _ = server_config;
let _ = socket;
let err = std::io::Error::new(std::io::ErrorKind::Other, "no async runtime found");
Err(Error::Io(err))
}
}
}
/// Extract the addr, quic version and peer id from the given [`Multiaddr`].
fn remote_multiaddr_to_socketaddr(
&self,
addr: Multiaddr,
check_unspecified_addr: bool,
) -> Result<
(SocketAddr, ProtocolVersion, Option<PeerId>),
TransportError<<Self as Transport>::Error>,
> {
let (socket_addr, version, peer_id) = multiaddr_to_socketaddr(&addr, self.support_draft_29)
.ok_or_else(|| TransportError::MultiaddrNotSupported(addr.clone()))?;
if socket_addr.port() == 0 || socket_addr.ip().is_unspecified() {
if check_unspecified_addr && (socket_addr.port() == 0 || socket_addr.ip().is_unspecified())
{
return Err(TransportError::MultiaddrNotSupported(addr));
}
Ok((socket_addr, version, peer_id))
}
/// Pick any listener to use for dialing.
fn eligible_listener(&mut self, socket_addr: &SocketAddr) -> Option<&mut Listener<P>> {
let mut listeners: Vec<_> = self
.listeners
@ -118,7 +154,7 @@ impl<P: Provider> GenTransport<P> {
if l.is_closed {
return false;
}
let listen_addr = l.endpoint_channel.socket_addr();
let listen_addr = l.socket_addr();
SocketFamily::is_same(&listen_addr.ip(), &socket_addr.ip())
&& listen_addr.ip().is_loopback() == socket_addr.ip().is_loopback()
})
@ -149,13 +185,16 @@ impl<P: Provider> Transport for GenTransport<P> {
listener_id: ListenerId,
addr: Multiaddr,
) -> Result<(), TransportError<Self::Error>> {
let (socket_addr, version, _peer_id) =
multiaddr_to_socketaddr(&addr, self.support_draft_29)
.ok_or(TransportError::MultiaddrNotSupported(addr))?;
let (socket_addr, version, _peer_id) = self.remote_multiaddr_to_socketaddr(addr, false)?;
let endpoint_config = self.quinn_config.endpoint_config.clone();
let server_config = self.quinn_config.server_config.clone();
let socket = UdpSocket::bind(socket_addr).map_err(Self::Error::from)?;
let socket_c = socket.try_clone().map_err(Self::Error::from)?;
let endpoint = Self::new_endpoint(endpoint_config, Some(server_config), socket)?;
let listener = Listener::new(
listener_id,
socket_addr,
self.quinn_config.clone(),
socket_c,
endpoint,
self.handshake_timeout,
version,
)?;
@ -194,46 +233,68 @@ impl<P: Provider> Transport for GenTransport<P> {
}
fn dial(&mut self, addr: Multiaddr) -> Result<Self::Dial, TransportError<Self::Error>> {
let (socket_addr, version, _peer_id) = self.remote_multiaddr_to_socketaddr(addr)?;
let (socket_addr, version, _peer_id) = self.remote_multiaddr_to_socketaddr(addr, true)?;
let handshake_timeout = self.handshake_timeout;
let dialer_state = match self.eligible_listener(&socket_addr) {
let endpoint = match self.eligible_listener(&socket_addr) {
None => {
// No listener. Get or create an explicit dialer.
let socket_family = socket_addr.ip().into();
let dialer = match self.dialer.entry(socket_family) {
Entry::Occupied(occupied) => occupied.into_mut(),
Entry::Occupied(occupied) => occupied.get().clone(),
Entry::Vacant(vacant) => {
if let Some(waker) = self.waker.take() {
waker.wake();
}
vacant.insert(Dialer::new::<P>(self.quinn_config.clone(), socket_family)?)
let listen_socket_addr = match socket_family {
SocketFamily::Ipv4 => SocketAddr::new(Ipv4Addr::UNSPECIFIED.into(), 0),
SocketFamily::Ipv6 => SocketAddr::new(Ipv6Addr::UNSPECIFIED.into(), 0),
};
let socket =
UdpSocket::bind(listen_socket_addr).map_err(Self::Error::from)?;
let endpoint_config = self.quinn_config.endpoint_config.clone();
let endpoint = Self::new_endpoint(endpoint_config, None, socket)?;
vacant.insert(endpoint.clone());
endpoint
}
};
&mut dialer.state
dialer
}
Some(listener) => &mut listener.dialer_state,
Some(listener) => listener.endpoint.clone(),
};
Ok(dialer_state.new_dial(socket_addr, handshake_timeout, version))
let handshake_timeout = self.handshake_timeout;
let mut client_config = self.quinn_config.client_config.clone();
if version == ProtocolVersion::Draft29 {
client_config.version(0xff00_001d);
}
Ok(Box::pin(async move {
// This `"l"` seems necessary because an empty string is an invalid domain
// name. While we don't use domain names, the underlying rustls library
// is based upon the assumption that we do.
let connecting = endpoint
.connect_with(client_config, socket_addr, "l")
.map_err(ConnectError)?;
Connecting::new(connecting, handshake_timeout).await
}))
}
fn dial_as_listener(
&mut self,
addr: Multiaddr,
) -> Result<Self::Dial, TransportError<Self::Error>> {
let (socket_addr, _version, peer_id) = self.remote_multiaddr_to_socketaddr(addr.clone())?;
let (socket_addr, _version, peer_id) =
self.remote_multiaddr_to_socketaddr(addr.clone(), true)?;
let peer_id = peer_id.ok_or(TransportError::MultiaddrNotSupported(addr))?;
let endpoint_channel = self
let socket = self
.eligible_listener(&socket_addr)
.ok_or(TransportError::Other(
Error::NoActiveListenerForDialAsListener,
))?
.endpoint_channel
.clone();
.try_clone_socket()
.map_err(Self::Error::from)?;
let hole_puncher = hole_puncher::<P>(endpoint_channel, socket_addr, self.handshake_timeout);
let hole_puncher = hole_puncher::<P>(socket, socket_addr, self.handshake_timeout);
let (sender, receiver) = oneshot::channel();
@ -274,19 +335,6 @@ impl<P: Provider> Transport for GenTransport<P> {
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<TransportEvent<Self::ListenerUpgrade, Self::Error>> {
let mut errored = Vec::new();
for (key, dialer) in &mut self.dialer {
if let Poll::Ready(_error) = dialer.poll(cx) {
errored.push(*key);
}
}
for key in errored {
// Endpoint driver of dialer crashed.
// Drop dialer and all pending dials so that the connection receiver is notified.
self.dialer.remove(&key);
}
while let Poll::Ready(Some(ev)) = self.listeners.poll_next_unpin(cx) {
match ev {
TransportEvent::Incoming {
@ -331,112 +379,22 @@ impl From<Error> for TransportError<Error> {
}
}
/// Dialer for addresses if no matching listener exists.
#[derive(Debug)]
struct Dialer {
/// Channel to the [`crate::endpoint::Driver`] that
/// is driving the endpoint.
endpoint_channel: endpoint::Channel,
/// Queued dials for the endpoint.
state: DialerState,
}
impl Dialer {
fn new<P: Provider>(
config: QuinnConfig,
socket_family: SocketFamily,
) -> Result<Self, TransportError<Error>> {
let endpoint_channel = endpoint::Channel::new_dialer::<P>(config, socket_family)
.map_err(TransportError::Other)?;
Ok(Dialer {
endpoint_channel,
state: DialerState::default(),
})
}
fn poll(&mut self, cx: &mut Context<'_>) -> Poll<Error> {
self.state.poll(&mut self.endpoint_channel, cx)
}
}
impl Drop for Dialer {
fn drop(&mut self) {
self.endpoint_channel.send_on_drop(ToEndpoint::Decoupled);
}
}
/// Pending dials to be sent to the endpoint was the [`endpoint::Channel`]
/// has capacity
#[derive(Default, Debug)]
struct DialerState {
pending_dials: VecDeque<ToEndpoint>,
waker: Option<Waker>,
}
impl DialerState {
fn new_dial(
&mut self,
address: SocketAddr,
timeout: Duration,
version: ProtocolVersion,
) -> BoxFuture<'static, Result<(PeerId, Connection), Error>> {
let (rx, tx) = oneshot::channel();
let message = ToEndpoint::Dial {
addr: address,
result: rx,
version,
};
self.pending_dials.push_back(message);
if let Some(waker) = self.waker.take() {
waker.wake();
}
async move {
// Our oneshot getting dropped means the message didn't make it to the endpoint driver.
let connection = tx.await.map_err(|_| Error::EndpointDriverCrashed)??;
let (peer, connection) = Connecting::new(connection, timeout).await?;
Ok((peer, connection))
}
.boxed()
}
/// Send all pending dials into the given [`endpoint::Channel`].
///
/// This only ever returns [`Poll::Pending`], or an error in case the channel is closed.
fn poll(&mut self, channel: &mut endpoint::Channel, cx: &mut Context<'_>) -> Poll<Error> {
while let Some(to_endpoint) = self.pending_dials.pop_front() {
match channel.try_send(to_endpoint, cx) {
Ok(Ok(())) => {}
Ok(Err(to_endpoint)) => {
self.pending_dials.push_front(to_endpoint);
break;
}
Err(endpoint::Disconnected {}) => return Poll::Ready(Error::EndpointDriverCrashed),
}
}
self.waker = Some(cx.waker().clone());
Poll::Pending
}
}
/// Listener for incoming connections.
struct Listener<P: Provider> {
/// Id of the listener.
listener_id: ListenerId,
/// Version of the supported quic protocol.
version: ProtocolVersion,
/// Channel to the endpoint to initiate dials.
endpoint_channel: endpoint::Channel,
/// Queued dials.
dialer_state: DialerState,
/// Endpoint
endpoint: quinn::Endpoint,
/// Channel where new connections are being sent.
new_connections_rx: mpsc::Receiver<Connection>,
/// An underlying copy of the socket to be able to hole punch with
socket: UdpSocket,
/// A future to poll new incoming connections.
accept: BoxFuture<'static, Option<quinn::Connecting>>,
/// Timeout for connection establishment on inbound connections.
handshake_timeout: Duration,
@ -458,38 +416,39 @@ struct Listener<P: Provider> {
impl<P: Provider> Listener<P> {
fn new(
listener_id: ListenerId,
socket_addr: SocketAddr,
config: QuinnConfig,
socket: UdpSocket,
endpoint: quinn::Endpoint,
handshake_timeout: Duration,
version: ProtocolVersion,
) -> Result<Self, Error> {
let (endpoint_channel, new_connections_rx) =
endpoint::Channel::new_bidirectional::<P>(config, socket_addr)?;
let if_watcher;
let pending_event;
if socket_addr.ip().is_unspecified() {
let local_addr = socket.local_addr()?;
if local_addr.ip().is_unspecified() {
if_watcher = Some(P::new_if_watcher()?);
pending_event = None;
} else {
if_watcher = None;
let ma = socketaddr_to_multiaddr(endpoint_channel.socket_addr(), version);
let ma = socketaddr_to_multiaddr(&local_addr, version);
pending_event = Some(TransportEvent::NewAddress {
listener_id,
listen_addr: ma,
})
}
let endpoint_c = endpoint.clone();
let accept = async move { endpoint_c.accept().await }.boxed();
Ok(Listener {
endpoint_channel,
endpoint,
socket,
accept,
listener_id,
version,
new_connections_rx,
handshake_timeout,
if_watcher,
is_closed: false,
pending_event,
dialer_state: DialerState::default(),
close_listener_waker: None,
})
}
@ -500,6 +459,7 @@ impl<P: Provider> Listener<P> {
if self.is_closed {
return;
}
self.endpoint.close(From::from(0u32), &[]);
self.pending_event = Some(TransportEvent::ListenerClosed {
listener_id: self.listener_id,
reason,
@ -512,8 +472,20 @@ impl<P: Provider> Listener<P> {
}
}
/// Clone underlying socket (for hole punching).
fn try_clone_socket(&self) -> std::io::Result<UdpSocket> {
self.socket.try_clone()
}
fn socket_addr(&self) -> SocketAddr {
self.socket
.local_addr()
.expect("Cannot fail because the socket is bound")
}
/// Poll for a next If Event.
fn poll_if_addr(&mut self, cx: &mut Context<'_>) -> Poll<<Self as Stream>::Item> {
let endpoint_addr = self.socket_addr();
let if_watcher = match self.if_watcher.as_mut() {
Some(iw) => iw,
None => return Poll::Pending,
@ -521,11 +493,9 @@ impl<P: Provider> Listener<P> {
loop {
match ready!(P::poll_if_event(if_watcher, cx)) {
Ok(IfEvent::Up(inet)) => {
if let Some(listen_addr) = ip_to_listenaddr(
self.endpoint_channel.socket_addr(),
inet.addr(),
self.version,
) {
if let Some(listen_addr) =
ip_to_listenaddr(&endpoint_addr, inet.addr(), self.version)
{
log::debug!("New listen address: {}", listen_addr);
return Poll::Ready(TransportEvent::NewAddress {
listener_id: self.listener_id,
@ -534,11 +504,9 @@ impl<P: Provider> Listener<P> {
}
}
Ok(IfEvent::Down(inet)) => {
if let Some(listen_addr) = ip_to_listenaddr(
self.endpoint_channel.socket_addr(),
inet.addr(),
self.version,
) {
if let Some(listen_addr) =
ip_to_listenaddr(&endpoint_addr, inet.addr(), self.version)
{
log::debug!("Expired listen address: {}", listen_addr);
return Poll::Ready(TransportEvent::AddressExpired {
listener_id: self.listener_id,
@ -555,21 +523,10 @@ impl<P: Provider> Listener<P> {
}
}
}
/// Poll [`DialerState`] to initiate requested dials.
fn poll_dialer(&mut self, cx: &mut Context<'_>) -> Poll<Error> {
let Self {
dialer_state,
endpoint_channel,
..
} = self;
dialer_state.poll(endpoint_channel, cx)
}
}
impl<P: Provider> Stream for Listener<P> {
type Item = TransportEvent<Connecting, Error>;
type Item = TransportEvent<<GenTransport<P> as Transport>::ListenerUpgrade, Error>;
fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
loop {
if let Some(event) = self.pending_event.take() {
@ -581,17 +538,18 @@ impl<P: Provider> Stream for Listener<P> {
if let Poll::Ready(event) = self.poll_if_addr(cx) {
return Poll::Ready(Some(event));
}
if let Poll::Ready(error) = self.poll_dialer(cx) {
self.close(Err(error));
continue;
}
match self.new_connections_rx.poll_next_unpin(cx) {
Poll::Ready(Some(connection)) => {
let local_addr = socketaddr_to_multiaddr(connection.local_addr(), self.version);
let send_back_addr =
socketaddr_to_multiaddr(&connection.remote_addr(), self.version);
match self.accept.poll_unpin(cx) {
Poll::Ready(Some(connecting)) => {
let endpoint = self.endpoint.clone();
self.accept = async move { endpoint.accept().await }.boxed();
let local_addr = socketaddr_to_multiaddr(&self.socket_addr(), self.version);
let remote_addr = connecting.remote_address();
let send_back_addr = socketaddr_to_multiaddr(&remote_addr, self.version);
let event = TransportEvent::Incoming {
upgrade: Connecting::new(connection, self.handshake_timeout),
upgrade: Connecting::new(connecting, self.handshake_timeout),
local_addr,
send_back_addr,
listener_id: self.listener_id,
@ -616,9 +574,6 @@ impl<P: Provider> fmt::Debug for Listener<P> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("Listener")
.field("listener_id", &self.listener_id)
.field("endpoint_channel", &self.endpoint_channel)
.field("dialer_state", &self.dialer_state)
.field("new_connections_rx", &self.new_connections_rx)
.field("handshake_timeout", &self.handshake_timeout)
.field("is_closed", &self.is_closed)
.field("pending_event", &self.pending_event)
@ -626,12 +581,6 @@ impl<P: Provider> fmt::Debug for Listener<P> {
}
}
impl<P: Provider> Drop for Listener<P> {
fn drop(&mut self) {
self.endpoint_channel.send_on_drop(ToEndpoint::Decoupled);
}
}
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub(crate) enum ProtocolVersion {
V1, // i.e. RFC9000
@ -766,7 +715,6 @@ fn socketaddr_to_multiaddr(socket_addr: &SocketAddr, version: ProtocolVersion) -
#[cfg(any(feature = "async-std", feature = "tokio"))]
mod test {
use futures::future::poll_fn;
use futures_timer::Delay;
use std::net::{IpAddr, Ipv4Addr, Ipv6Addr};
use super::*;
@ -882,15 +830,6 @@ mod test {
.listen_on(id, "/ip4/0.0.0.0/udp/0/quic-v1".parse().unwrap())
.unwrap();
// Copy channel to use it later.
let mut channel = transport
.listeners
.iter()
.next()
.unwrap()
.endpoint_channel
.clone();
match poll_fn(|cx| Pin::new(&mut transport).as_mut().poll(cx)).await {
TransportEvent::NewAddress {
listener_id,
@ -923,14 +862,6 @@ mod test {
.now_or_never()
.is_none());
assert!(transport.listeners.is_empty());
// Check that the [`Driver`] has shut down.
Delay::new(Duration::from_millis(10)).await;
poll_fn(|cx| {
assert!(channel.try_send(ToEndpoint::Decoupled, cx).is_err());
Poll::Ready(())
})
.await;
}
}
@ -945,32 +876,9 @@ mod test {
.dial("/ip4/123.45.67.8/udp/1234/quic-v1".parse().unwrap())
.unwrap();
// Expect a dialer and its background task to exist.
let mut channel = transport
.dialer
.get(&SocketFamily::Ipv4)
.unwrap()
.endpoint_channel
.clone();
assert!(transport.dialer.contains_key(&SocketFamily::Ipv4));
assert!(!transport.dialer.contains_key(&SocketFamily::Ipv6));
// Send dummy dial to check that the endpoint driver is running.
poll_fn(|cx| {
let (tx, _) = oneshot::channel();
let _ = channel
.try_send(
ToEndpoint::Dial {
addr: SocketAddr::new(IpAddr::V4(Ipv4Addr::UNSPECIFIED), 0),
result: tx,
version: ProtocolVersion::V1,
},
cx,
)
.unwrap();
Poll::Ready(())
})
.await;
// Start listening so that the dialer and driver are dropped.
transport
.listen_on(
@ -979,13 +887,5 @@ mod test {
)
.unwrap();
assert!(!transport.dialer.contains_key(&SocketFamily::Ipv4));
// Check that the [`Driver`] has shut down.
Delay::new(Duration::from_millis(10)).await;
poll_fn(|cx| {
assert!(channel.try_send(ToEndpoint::Decoupled, cx).is_err());
Poll::Ready(())
})
.await;
}
}

View File

@ -428,7 +428,7 @@ async fn smoke<P: Provider>() {
assert_eq!(b_connected, a_peer_id);
}
async fn build_streams<P: Provider>() -> (SubstreamBox, SubstreamBox) {
async fn build_streams<P: Provider + Spawn>() -> (SubstreamBox, SubstreamBox) {
let (_, mut a_transport) = create_default_transport::<P>();
let (_, mut b_transport) = create_default_transport::<P>();
@ -522,7 +522,7 @@ async fn start_listening(transport: &mut Boxed<(PeerId, StreamMuxerBox)>, addr:
}
}
fn prop<P: Provider + BlockOn>(
fn prop<P: Provider + BlockOn + Spawn>(
number_listeners: NonZeroU8,
number_streams: NonZeroU8,
) -> quickcheck::TestResult {
@ -599,7 +599,7 @@ fn prop<P: Provider + BlockOn>(
quickcheck::TestResult::passed()
}
async fn answer_inbound_streams<P: Provider, const BUFFER_SIZE: usize>(
async fn answer_inbound_streams<P: Provider + Spawn, const BUFFER_SIZE: usize>(
mut connection: StreamMuxerBox,
) {
loop {
@ -634,7 +634,7 @@ async fn answer_inbound_streams<P: Provider, const BUFFER_SIZE: usize>(
}
}
async fn open_outbound_streams<P: Provider, const BUFFER_SIZE: usize>(
async fn open_outbound_streams<P: Provider + Spawn, const BUFFER_SIZE: usize>(
mut connection: StreamMuxerBox,
number_streams: usize,
completed_streams_tx: mpsc::Sender<()>,
@ -740,3 +740,22 @@ impl BlockOn for libp2p_quic::tokio::Provider {
.unwrap()
}
}
trait Spawn {
/// Run the given future in the background until it ends.
fn spawn(future: impl Future<Output = ()> + Send + 'static);
}
#[cfg(feature = "async-std")]
impl Spawn for libp2p_quic::async_std::Provider {
fn spawn(future: impl Future<Output = ()> + Send + 'static) {
async_std::task::spawn(future);
}
}
#[cfg(feature = "tokio")]
impl Spawn for libp2p_quic::tokio::Provider {
fn spawn(future: impl Future<Output = ()> + Send + 'static) {
tokio::spawn(future);
}
}