mirror of
https://github.com/fluencelabs/rust-libp2p
synced 2025-06-30 10:11:33 +00:00
refactor(quic): rewrite quic using quinn
Rewrite quic using quinn instead of quinn-proto. libp2p-quic::endpoint::Driver is eliminated (and that hard quinn-proto machinery). Also: - ECN bits are handled - Support Generic Send Offload (GSO) Pull-Request: #3454.
This commit is contained in:
37
Cargo.lock
generated
37
Cargo.lock
generated
@ -3040,7 +3040,7 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "libp2p-quic"
|
||||
version = "0.8.0-alpha"
|
||||
version = "0.9.0-alpha"
|
||||
dependencies = [
|
||||
"async-std",
|
||||
"bytes",
|
||||
@ -3058,7 +3058,7 @@ dependencies = [
|
||||
"log",
|
||||
"parking_lot",
|
||||
"quickcheck",
|
||||
"quinn-proto",
|
||||
"quinn",
|
||||
"rand 0.8.5",
|
||||
"rustls 0.21.2",
|
||||
"thiserror",
|
||||
@ -4313,6 +4313,26 @@ dependencies = [
|
||||
"pin-project-lite 0.1.12",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "quinn"
|
||||
version = "0.10.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "21252f1c0fc131f1b69182db8f34837e8a69737b8251dff75636a9be0518c324"
|
||||
dependencies = [
|
||||
"async-io",
|
||||
"async-std",
|
||||
"bytes",
|
||||
"futures-io",
|
||||
"pin-project-lite 0.2.9",
|
||||
"quinn-proto",
|
||||
"quinn-udp",
|
||||
"rustc-hash",
|
||||
"rustls 0.21.2",
|
||||
"thiserror",
|
||||
"tokio",
|
||||
"tracing",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "quinn-proto"
|
||||
version = "0.10.1"
|
||||
@ -4330,6 +4350,19 @@ dependencies = [
|
||||
"tracing",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "quinn-udp"
|
||||
version = "0.4.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "6df19e284d93757a9fb91d63672f7741b129246a669db09d1c0063071debc0c0"
|
||||
dependencies = [
|
||||
"bytes",
|
||||
"libc",
|
||||
"socket2 0.5.3",
|
||||
"tracing",
|
||||
"windows-sys 0.48.0",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "quote"
|
||||
version = "1.0.32"
|
||||
|
@ -83,7 +83,7 @@ libp2p-perf = { version = "0.2.0", path = "protocols/perf" }
|
||||
libp2p-ping = { version = "0.43.0", path = "protocols/ping" }
|
||||
libp2p-plaintext = { version = "0.40.0", path = "transports/plaintext" }
|
||||
libp2p-pnet = { version = "0.23.0", path = "transports/pnet" }
|
||||
libp2p-quic = { version = "0.8.0-alpha", path = "transports/quic" }
|
||||
libp2p-quic = { version = "0.9.0-alpha", path = "transports/quic" }
|
||||
libp2p-relay = { version = "0.16.1", path = "protocols/relay" }
|
||||
libp2p-rendezvous = { version = "0.13.0", path = "protocols/rendezvous" }
|
||||
libp2p-request-response = { version = "0.25.1", path = "protocols/request-response" }
|
||||
|
@ -1,3 +1,10 @@
|
||||
## 0.9.0-alpha - unreleased
|
||||
|
||||
- Use `quinn` instead of `quinn-proto`.
|
||||
See [PR 3454].
|
||||
|
||||
[PR 3454]: https://github.com/libp2p/rust-libp2p/pull/3454
|
||||
|
||||
## 0.8.0-alpha
|
||||
|
||||
- Raise MSRV to 1.65.
|
||||
|
@ -1,6 +1,6 @@
|
||||
[package]
|
||||
name = "libp2p-quic"
|
||||
version = "0.8.0-alpha"
|
||||
version = "0.9.0-alpha"
|
||||
authors = ["Parity Technologies <admin@parity.io>"]
|
||||
edition = "2021"
|
||||
rust-version = { workspace = true }
|
||||
@ -19,15 +19,15 @@ libp2p-tls = { workspace = true }
|
||||
libp2p-identity = { workspace = true }
|
||||
log = "0.4"
|
||||
parking_lot = "0.12.0"
|
||||
quinn-proto = { version = "0.10.1", default-features = false, features = ["tls-rustls"] }
|
||||
quinn = { version = "0.10.1", default-features = false, features = ["tls-rustls", "futures-io"] }
|
||||
rand = "0.8.5"
|
||||
rustls = { version = "0.21.2", default-features = false }
|
||||
thiserror = "1.0.44"
|
||||
tokio = { version = "1.29.1", default-features = false, features = ["net", "rt", "time"], optional = true }
|
||||
|
||||
[features]
|
||||
tokio = ["dep:tokio", "if-watch/tokio"]
|
||||
async-std = ["dep:async-std", "if-watch/smol"]
|
||||
tokio = ["dep:tokio", "if-watch/tokio", "quinn/runtime-tokio"]
|
||||
async-std = ["dep:async-std", "if-watch/smol", "quinn/runtime-async-std"]
|
||||
|
||||
# Passing arguments to the docsrs builder in order to properly document cfg's.
|
||||
# More information: https://docs.rs/about/builds#cross-compiling
|
||||
|
142
transports/quic/src/config.rs
Normal file
142
transports/quic/src/config.rs
Normal file
@ -0,0 +1,142 @@
|
||||
// Copyright 2017-2020 Parity Technologies (UK) Ltd.
|
||||
//
|
||||
// Permission is hereby granted, free of charge, to any person obtaining a
|
||||
// copy of this software and associated documentation files (the "Software"),
|
||||
// to deal in the Software without restriction, including without limitation
|
||||
// the rights to use, copy, modify, merge, publish, distribute, sublicense,
|
||||
// and/or sell copies of the Software, and to permit persons to whom the
|
||||
// Software is furnished to do so, subject to the following conditions:
|
||||
//
|
||||
// The above copyright notice and this permission notice shall be included in
|
||||
// all copies or substantial portions of the Software.
|
||||
//
|
||||
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
|
||||
// OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
|
||||
// FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
|
||||
// DEALINGS IN THE SOFTWARE.
|
||||
|
||||
use quinn::VarInt;
|
||||
use std::{sync::Arc, time::Duration};
|
||||
|
||||
/// Config for the transport.
|
||||
#[derive(Clone)]
|
||||
pub struct Config {
|
||||
/// Timeout for the initial handshake when establishing a connection.
|
||||
/// The actual timeout is the minimum of this and the [`Config::max_idle_timeout`].
|
||||
pub handshake_timeout: Duration,
|
||||
/// Maximum duration of inactivity in ms to accept before timing out the connection.
|
||||
pub max_idle_timeout: u32,
|
||||
/// Period of inactivity before sending a keep-alive packet.
|
||||
/// Must be set lower than the idle_timeout of both
|
||||
/// peers to be effective.
|
||||
///
|
||||
/// See [`quinn::TransportConfig::keep_alive_interval`] for more
|
||||
/// info.
|
||||
pub keep_alive_interval: Duration,
|
||||
/// Maximum number of incoming bidirectional streams that may be open
|
||||
/// concurrently by the remote peer.
|
||||
pub max_concurrent_stream_limit: u32,
|
||||
|
||||
/// Max unacknowledged data in bytes that may be send on a single stream.
|
||||
pub max_stream_data: u32,
|
||||
|
||||
/// Max unacknowledged data in bytes that may be send in total on all streams
|
||||
/// of a connection.
|
||||
pub max_connection_data: u32,
|
||||
|
||||
/// Support QUIC version draft-29 for dialing and listening.
|
||||
///
|
||||
/// Per default only QUIC Version 1 / [`libp2p_core::multiaddr::Protocol::QuicV1`]
|
||||
/// is supported.
|
||||
///
|
||||
/// If support for draft-29 is enabled servers support draft-29 and version 1 on all
|
||||
/// QUIC listening addresses.
|
||||
/// As client the version is chosen based on the remote's address.
|
||||
pub support_draft_29: bool,
|
||||
|
||||
/// TLS client config for the inner [`quinn::ClientConfig`].
|
||||
client_tls_config: Arc<rustls::ClientConfig>,
|
||||
/// TLS server config for the inner [`quinn::ServerConfig`].
|
||||
server_tls_config: Arc<rustls::ServerConfig>,
|
||||
}
|
||||
|
||||
impl Config {
|
||||
/// Creates a new configuration object with default values.
|
||||
pub fn new(keypair: &libp2p_identity::Keypair) -> Self {
|
||||
let client_tls_config = Arc::new(libp2p_tls::make_client_config(keypair, None).unwrap());
|
||||
let server_tls_config = Arc::new(libp2p_tls::make_server_config(keypair).unwrap());
|
||||
Self {
|
||||
client_tls_config,
|
||||
server_tls_config,
|
||||
support_draft_29: false,
|
||||
handshake_timeout: Duration::from_secs(5),
|
||||
max_idle_timeout: 30 * 1000,
|
||||
max_concurrent_stream_limit: 256,
|
||||
keep_alive_interval: Duration::from_secs(15),
|
||||
max_connection_data: 15_000_000,
|
||||
|
||||
// Ensure that one stream is not consuming the whole connection.
|
||||
max_stream_data: 10_000_000,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Represents the inner configuration for [`quinn`].
|
||||
#[derive(Debug, Clone)]
|
||||
pub(crate) struct QuinnConfig {
|
||||
pub(crate) client_config: quinn::ClientConfig,
|
||||
pub(crate) server_config: quinn::ServerConfig,
|
||||
pub(crate) endpoint_config: quinn::EndpointConfig,
|
||||
}
|
||||
|
||||
impl From<Config> for QuinnConfig {
|
||||
fn from(config: Config) -> QuinnConfig {
|
||||
let Config {
|
||||
client_tls_config,
|
||||
server_tls_config,
|
||||
max_idle_timeout,
|
||||
max_concurrent_stream_limit,
|
||||
keep_alive_interval,
|
||||
max_connection_data,
|
||||
max_stream_data,
|
||||
support_draft_29,
|
||||
handshake_timeout: _,
|
||||
} = config;
|
||||
let mut transport = quinn::TransportConfig::default();
|
||||
// Disable uni-directional streams.
|
||||
transport.max_concurrent_uni_streams(0u32.into());
|
||||
transport.max_concurrent_bidi_streams(max_concurrent_stream_limit.into());
|
||||
// Disable datagrams.
|
||||
transport.datagram_receive_buffer_size(None);
|
||||
transport.keep_alive_interval(Some(keep_alive_interval));
|
||||
transport.max_idle_timeout(Some(VarInt::from_u32(max_idle_timeout).into()));
|
||||
transport.allow_spin(false);
|
||||
transport.stream_receive_window(max_stream_data.into());
|
||||
transport.receive_window(max_connection_data.into());
|
||||
let transport = Arc::new(transport);
|
||||
|
||||
let mut server_config = quinn::ServerConfig::with_crypto(server_tls_config);
|
||||
server_config.transport = Arc::clone(&transport);
|
||||
// Disables connection migration.
|
||||
// Long-term this should be enabled, however we then need to handle address change
|
||||
// on connections in the `Connection`.
|
||||
server_config.migration(false);
|
||||
|
||||
let mut client_config = quinn::ClientConfig::new(client_tls_config);
|
||||
client_config.transport_config(transport);
|
||||
|
||||
let mut endpoint_config = quinn::EndpointConfig::default();
|
||||
if !support_draft_29 {
|
||||
endpoint_config.supported_versions(vec![1]);
|
||||
}
|
||||
|
||||
QuinnConfig {
|
||||
client_config,
|
||||
server_config,
|
||||
endpoint_config,
|
||||
}
|
||||
}
|
||||
}
|
@ -19,409 +19,113 @@
|
||||
// DEALINGS IN THE SOFTWARE.
|
||||
|
||||
mod connecting;
|
||||
mod substream;
|
||||
mod stream;
|
||||
|
||||
use crate::{
|
||||
endpoint::{self, ToEndpoint},
|
||||
Error,
|
||||
};
|
||||
pub use connecting::Connecting;
|
||||
pub use substream::Substream;
|
||||
use substream::{SubstreamState, WriteState};
|
||||
pub use stream::Stream;
|
||||
|
||||
use futures::{channel::mpsc, ready, FutureExt, StreamExt};
|
||||
use futures_timer::Delay;
|
||||
use crate::{ConnectionError, Error};
|
||||
|
||||
use futures::{future::BoxFuture, FutureExt};
|
||||
use libp2p_core::muxing::{StreamMuxer, StreamMuxerEvent};
|
||||
use parking_lot::Mutex;
|
||||
use std::{
|
||||
any::Any,
|
||||
collections::HashMap,
|
||||
net::SocketAddr,
|
||||
pin::Pin,
|
||||
sync::Arc,
|
||||
task::{Context, Poll, Waker},
|
||||
time::Instant,
|
||||
task::{Context, Poll},
|
||||
};
|
||||
|
||||
/// State for a single opened QUIC connection.
|
||||
#[derive(Debug)]
|
||||
pub struct Connection {
|
||||
/// State shared with the substreams.
|
||||
state: Arc<Mutex<State>>,
|
||||
/// Channel to the [`endpoint::Driver`] that drives the [`quinn_proto::Endpoint`] that
|
||||
/// this connection belongs to.
|
||||
endpoint_channel: endpoint::Channel,
|
||||
/// Pending message to be sent to the [`quinn_proto::Endpoint`] in the [`endpoint::Driver`].
|
||||
pending_to_endpoint: Option<ToEndpoint>,
|
||||
/// Events that the [`quinn_proto::Endpoint`] will send in destination to our local
|
||||
/// [`quinn_proto::Connection`].
|
||||
from_endpoint: mpsc::Receiver<quinn_proto::ConnectionEvent>,
|
||||
/// Identifier for this connection according to the [`quinn_proto::Endpoint`].
|
||||
/// Used when sending messages to the endpoint.
|
||||
connection_id: quinn_proto::ConnectionHandle,
|
||||
/// `Future` that triggers at the [`Instant`] that [`quinn_proto::Connection::poll_timeout`]
|
||||
/// indicates.
|
||||
next_timeout: Option<(Delay, Instant)>,
|
||||
/// Underlying connection.
|
||||
connection: quinn::Connection,
|
||||
/// Future for accepting a new incoming bidirectional stream.
|
||||
incoming: Option<
|
||||
BoxFuture<'static, Result<(quinn::SendStream, quinn::RecvStream), quinn::ConnectionError>>,
|
||||
>,
|
||||
/// Future for opening a new outgoing bidirectional stream.
|
||||
outgoing: Option<
|
||||
BoxFuture<'static, Result<(quinn::SendStream, quinn::RecvStream), quinn::ConnectionError>>,
|
||||
>,
|
||||
/// Future to wait for the connection to be closed.
|
||||
closing: Option<BoxFuture<'static, quinn::ConnectionError>>,
|
||||
}
|
||||
|
||||
impl Connection {
|
||||
/// Build a [`Connection`] from raw components.
|
||||
///
|
||||
/// This function assumes that there exists a [`Driver`](super::endpoint::Driver)
|
||||
/// that will process the messages sent to `EndpointChannel::to_endpoint` and send us messages
|
||||
/// on `from_endpoint`.
|
||||
///
|
||||
/// `connection_id` is used to identify the local connection in the messages sent to
|
||||
/// `to_endpoint`.
|
||||
///
|
||||
/// This function assumes that the [`quinn_proto::Connection`] is completely fresh and none of
|
||||
/// This function assumes that the [`quinn::Connection`] is completely fresh and none of
|
||||
/// its methods has ever been called. Failure to comply might lead to logic errors and panics.
|
||||
pub(crate) fn from_quinn_connection(
|
||||
endpoint_channel: endpoint::Channel,
|
||||
connection: quinn_proto::Connection,
|
||||
connection_id: quinn_proto::ConnectionHandle,
|
||||
from_endpoint: mpsc::Receiver<quinn_proto::ConnectionEvent>,
|
||||
) -> Self {
|
||||
let state = State {
|
||||
connection,
|
||||
substreams: HashMap::new(),
|
||||
poll_connection_waker: None,
|
||||
poll_inbound_waker: None,
|
||||
poll_outbound_waker: None,
|
||||
};
|
||||
fn new(connection: quinn::Connection) -> Self {
|
||||
Self {
|
||||
endpoint_channel,
|
||||
pending_to_endpoint: None,
|
||||
next_timeout: None,
|
||||
from_endpoint,
|
||||
connection_id,
|
||||
state: Arc::new(Mutex::new(state)),
|
||||
}
|
||||
}
|
||||
|
||||
/// The address that the local socket is bound to.
|
||||
pub(crate) fn local_addr(&self) -> &SocketAddr {
|
||||
self.endpoint_channel.socket_addr()
|
||||
}
|
||||
|
||||
/// Returns the address of the node we're connected to.
|
||||
pub(crate) fn remote_addr(&self) -> SocketAddr {
|
||||
self.state.lock().connection.remote_address()
|
||||
}
|
||||
|
||||
/// Identity of the remote peer inferred from the handshake.
|
||||
///
|
||||
/// `None` if the handshake is not complete yet, i.e. [`Self::poll_event`]
|
||||
/// has not yet reported a [`quinn_proto::Event::Connected`]
|
||||
fn peer_identity(&self) -> Option<Box<dyn Any>> {
|
||||
self.state
|
||||
.lock()
|
||||
.connection
|
||||
.crypto_session()
|
||||
.peer_identity()
|
||||
}
|
||||
|
||||
/// Polls the connection for an event that happened on it.
|
||||
///
|
||||
/// `quinn::proto::Connection` is polled in the order instructed in their docs:
|
||||
/// 1. [`quinn_proto::Connection::poll_transmit`]
|
||||
/// 2. [`quinn_proto::Connection::poll_timeout`]
|
||||
/// 3. [`quinn_proto::Connection::poll_endpoint_events`]
|
||||
/// 4. [`quinn_proto::Connection::poll`]
|
||||
fn poll_event(&mut self, cx: &mut Context<'_>) -> Poll<Option<quinn_proto::Event>> {
|
||||
let mut inner = self.state.lock();
|
||||
loop {
|
||||
// Sending the pending event to the endpoint. If the endpoint is too busy, we just
|
||||
// stop the processing here.
|
||||
// We don't deliver substream-related events to the user as long as
|
||||
// `to_endpoint` is full. This should propagate the back-pressure of `to_endpoint`
|
||||
// being full to the user.
|
||||
if let Some(to_endpoint) = self.pending_to_endpoint.take() {
|
||||
match self.endpoint_channel.try_send(to_endpoint, cx) {
|
||||
Ok(Ok(())) => {}
|
||||
Ok(Err(to_endpoint)) => {
|
||||
self.pending_to_endpoint = Some(to_endpoint);
|
||||
return Poll::Pending;
|
||||
}
|
||||
Err(endpoint::Disconnected {}) => {
|
||||
return Poll::Ready(None);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
match self.from_endpoint.poll_next_unpin(cx) {
|
||||
Poll::Ready(Some(event)) => {
|
||||
inner.connection.handle_event(event);
|
||||
continue;
|
||||
}
|
||||
Poll::Ready(None) => {
|
||||
return Poll::Ready(None);
|
||||
}
|
||||
Poll::Pending => {}
|
||||
}
|
||||
|
||||
// The maximum amount of segments which can be transmitted in a single Transmit
|
||||
// if a platform supports Generic Send Offload (GSO).
|
||||
// Set to 1 for now since not all platforms support GSO.
|
||||
// TODO: Fix for platforms that support GSO.
|
||||
let max_datagrams = 1;
|
||||
// Poll the connection for packets to send on the UDP socket and try to send them on
|
||||
// `to_endpoint`.
|
||||
if let Some(transmit) = inner
|
||||
.connection
|
||||
.poll_transmit(Instant::now(), max_datagrams)
|
||||
{
|
||||
// TODO: ECN bits not handled
|
||||
self.pending_to_endpoint = Some(ToEndpoint::SendUdpPacket(transmit));
|
||||
continue;
|
||||
}
|
||||
|
||||
match inner.connection.poll_timeout() {
|
||||
Some(timeout) => match self.next_timeout {
|
||||
Some((_, when)) if when == timeout => {}
|
||||
_ => {
|
||||
let now = Instant::now();
|
||||
// 0ns if now > when
|
||||
let duration = timeout.duration_since(now);
|
||||
let next_timeout = Delay::new(duration);
|
||||
self.next_timeout = Some((next_timeout, timeout))
|
||||
}
|
||||
},
|
||||
None => self.next_timeout = None,
|
||||
}
|
||||
|
||||
if let Some((timeout, when)) = self.next_timeout.as_mut() {
|
||||
if timeout.poll_unpin(cx).is_ready() {
|
||||
inner.connection.handle_timeout(*when);
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
// The connection also needs to be able to send control messages to the endpoint. This is
|
||||
// handled here, and we try to send them on `to_endpoint` as well.
|
||||
if let Some(event) = inner.connection.poll_endpoint_events() {
|
||||
let connection_id = self.connection_id;
|
||||
self.pending_to_endpoint = Some(ToEndpoint::ProcessConnectionEvent {
|
||||
connection_id,
|
||||
event,
|
||||
});
|
||||
continue;
|
||||
}
|
||||
|
||||
// The final step consists in returning the events related to the various substreams.
|
||||
if let Some(ev) = inner.connection.poll() {
|
||||
return Poll::Ready(Some(ev));
|
||||
}
|
||||
|
||||
return Poll::Pending;
|
||||
connection,
|
||||
incoming: None,
|
||||
outgoing: None,
|
||||
closing: None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl StreamMuxer for Connection {
|
||||
type Substream = Substream;
|
||||
type Substream = Stream;
|
||||
type Error = Error;
|
||||
|
||||
fn poll(
|
||||
mut self: Pin<&mut Self>,
|
||||
cx: &mut Context<'_>,
|
||||
) -> Poll<Result<StreamMuxerEvent, Self::Error>> {
|
||||
while let Poll::Ready(event) = self.poll_event(cx) {
|
||||
let mut inner = self.state.lock();
|
||||
let event = match event {
|
||||
Some(event) => event,
|
||||
None => return Poll::Ready(Err(Error::EndpointDriverCrashed)),
|
||||
};
|
||||
match event {
|
||||
quinn_proto::Event::Connected | quinn_proto::Event::HandshakeDataReady => {
|
||||
debug_assert!(
|
||||
false,
|
||||
"Unexpected event {event:?} on established QUIC connection"
|
||||
);
|
||||
}
|
||||
quinn_proto::Event::ConnectionLost { reason } => {
|
||||
inner
|
||||
.connection
|
||||
.close(Instant::now(), From::from(0u32), Default::default());
|
||||
inner.substreams.values_mut().for_each(|s| s.wake_all());
|
||||
return Poll::Ready(Err(Error::Connection(reason.into())));
|
||||
}
|
||||
quinn_proto::Event::Stream(quinn_proto::StreamEvent::Opened {
|
||||
dir: quinn_proto::Dir::Bi,
|
||||
}) => {
|
||||
if let Some(waker) = inner.poll_outbound_waker.take() {
|
||||
waker.wake();
|
||||
}
|
||||
}
|
||||
quinn_proto::Event::Stream(quinn_proto::StreamEvent::Available {
|
||||
dir: quinn_proto::Dir::Bi,
|
||||
}) => {
|
||||
if let Some(waker) = inner.poll_inbound_waker.take() {
|
||||
waker.wake();
|
||||
}
|
||||
}
|
||||
quinn_proto::Event::Stream(quinn_proto::StreamEvent::Readable { id }) => {
|
||||
if let Some(substream) = inner.substreams.get_mut(&id) {
|
||||
if let Some(waker) = substream.read_waker.take() {
|
||||
waker.wake();
|
||||
}
|
||||
}
|
||||
}
|
||||
quinn_proto::Event::Stream(quinn_proto::StreamEvent::Writable { id }) => {
|
||||
if let Some(substream) = inner.substreams.get_mut(&id) {
|
||||
if let Some(waker) = substream.write_waker.take() {
|
||||
waker.wake();
|
||||
}
|
||||
}
|
||||
}
|
||||
quinn_proto::Event::Stream(quinn_proto::StreamEvent::Finished { id }) => {
|
||||
if let Some(substream) = inner.substreams.get_mut(&id) {
|
||||
if matches!(
|
||||
substream.write_state,
|
||||
WriteState::Open | WriteState::Closing
|
||||
) {
|
||||
substream.write_state = WriteState::Closed;
|
||||
}
|
||||
if let Some(waker) = substream.write_waker.take() {
|
||||
waker.wake();
|
||||
}
|
||||
if let Some(waker) = substream.close_waker.take() {
|
||||
waker.wake();
|
||||
}
|
||||
}
|
||||
}
|
||||
quinn_proto::Event::Stream(quinn_proto::StreamEvent::Stopped {
|
||||
id,
|
||||
error_code: _,
|
||||
}) => {
|
||||
if let Some(substream) = inner.substreams.get_mut(&id) {
|
||||
substream.write_state = WriteState::Stopped;
|
||||
if let Some(waker) = substream.write_waker.take() {
|
||||
waker.wake();
|
||||
}
|
||||
if let Some(waker) = substream.close_waker.take() {
|
||||
waker.wake();
|
||||
}
|
||||
}
|
||||
}
|
||||
quinn_proto::Event::DatagramReceived
|
||||
| quinn_proto::Event::Stream(quinn_proto::StreamEvent::Available {
|
||||
dir: quinn_proto::Dir::Uni,
|
||||
})
|
||||
| quinn_proto::Event::Stream(quinn_proto::StreamEvent::Opened {
|
||||
dir: quinn_proto::Dir::Uni,
|
||||
}) => {
|
||||
unreachable!("We don't use datagrams or unidirectional streams.")
|
||||
}
|
||||
}
|
||||
}
|
||||
// TODO: If connection migration is enabled (currently disabled) address
|
||||
// change on the connection needs to be handled.
|
||||
|
||||
self.state.lock().poll_connection_waker = Some(cx.waker().clone());
|
||||
Poll::Pending
|
||||
}
|
||||
|
||||
fn poll_inbound(
|
||||
self: Pin<&mut Self>,
|
||||
cx: &mut Context<'_>,
|
||||
) -> Poll<Result<Self::Substream, Self::Error>> {
|
||||
let mut inner = self.state.lock();
|
||||
let this = self.get_mut();
|
||||
|
||||
let substream_id = match inner.connection.streams().accept(quinn_proto::Dir::Bi) {
|
||||
Some(id) => {
|
||||
inner.poll_inbound_waker = None;
|
||||
id
|
||||
}
|
||||
None => {
|
||||
inner.poll_inbound_waker = Some(cx.waker().clone());
|
||||
return Poll::Pending;
|
||||
}
|
||||
};
|
||||
inner.substreams.insert(substream_id, Default::default());
|
||||
let substream = Substream::new(substream_id, self.state.clone());
|
||||
let incoming = this.incoming.get_or_insert_with(|| {
|
||||
let connection = this.connection.clone();
|
||||
async move { connection.accept_bi().await }.boxed()
|
||||
});
|
||||
|
||||
Poll::Ready(Ok(substream))
|
||||
let (send, recv) = futures::ready!(incoming.poll_unpin(cx)).map_err(ConnectionError)?;
|
||||
this.incoming.take();
|
||||
let stream = Stream::new(send, recv);
|
||||
Poll::Ready(Ok(stream))
|
||||
}
|
||||
|
||||
fn poll_outbound(
|
||||
self: Pin<&mut Self>,
|
||||
cx: &mut Context<'_>,
|
||||
) -> Poll<Result<Self::Substream, Self::Error>> {
|
||||
let mut inner = self.state.lock();
|
||||
let substream_id = match inner.connection.streams().open(quinn_proto::Dir::Bi) {
|
||||
Some(id) => {
|
||||
inner.poll_outbound_waker = None;
|
||||
id
|
||||
}
|
||||
None => {
|
||||
inner.poll_outbound_waker = Some(cx.waker().clone());
|
||||
return Poll::Pending;
|
||||
}
|
||||
let this = self.get_mut();
|
||||
|
||||
let outgoing = this.outgoing.get_or_insert_with(|| {
|
||||
let connection = this.connection.clone();
|
||||
async move { connection.open_bi().await }.boxed()
|
||||
});
|
||||
|
||||
let (send, recv) = futures::ready!(outgoing.poll_unpin(cx)).map_err(ConnectionError)?;
|
||||
this.outgoing.take();
|
||||
let stream = Stream::new(send, recv);
|
||||
Poll::Ready(Ok(stream))
|
||||
}
|
||||
|
||||
fn poll(
|
||||
self: Pin<&mut Self>,
|
||||
_cx: &mut Context<'_>,
|
||||
) -> Poll<Result<StreamMuxerEvent, Self::Error>> {
|
||||
// TODO: If connection migration is enabled (currently disabled) address
|
||||
// change on the connection needs to be handled.
|
||||
Poll::Pending
|
||||
}
|
||||
|
||||
fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
|
||||
let this = self.get_mut();
|
||||
|
||||
let closing = this.closing.get_or_insert_with(|| {
|
||||
this.connection.close(From::from(0u32), &[]);
|
||||
let connection = this.connection.clone();
|
||||
async move { connection.closed().await }.boxed()
|
||||
});
|
||||
|
||||
match futures::ready!(closing.poll_unpin(cx)) {
|
||||
// Expected error given that `connection.close` was called above.
|
||||
quinn::ConnectionError::LocallyClosed => {}
|
||||
error => return Poll::Ready(Err(Error::Connection(ConnectionError(error)))),
|
||||
};
|
||||
inner.substreams.insert(substream_id, Default::default());
|
||||
let substream = Substream::new(substream_id, self.state.clone());
|
||||
Poll::Ready(Ok(substream))
|
||||
}
|
||||
|
||||
fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
|
||||
let mut inner = self.state.lock();
|
||||
if inner.connection.is_drained() {
|
||||
return Poll::Ready(Ok(()));
|
||||
}
|
||||
|
||||
for substream in inner.substreams.keys().cloned().collect::<Vec<_>>() {
|
||||
let _ = inner.connection.send_stream(substream).finish();
|
||||
}
|
||||
|
||||
if inner.connection.streams().send_streams() == 0 && !inner.connection.is_closed() {
|
||||
inner
|
||||
.connection
|
||||
.close(Instant::now(), From::from(0u32), Default::default())
|
||||
}
|
||||
drop(inner);
|
||||
|
||||
loop {
|
||||
match ready!(self.poll_event(cx)) {
|
||||
Some(quinn_proto::Event::ConnectionLost { .. }) => return Poll::Ready(Ok(())),
|
||||
None => return Poll::Ready(Err(Error::EndpointDriverCrashed)),
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Drop for Connection {
|
||||
fn drop(&mut self) {
|
||||
let to_endpoint = ToEndpoint::ProcessConnectionEvent {
|
||||
connection_id: self.connection_id,
|
||||
event: quinn_proto::EndpointEvent::drained(),
|
||||
};
|
||||
self.endpoint_channel.send_on_drop(to_endpoint);
|
||||
}
|
||||
}
|
||||
|
||||
/// Mutex-protected state of [`Connection`].
|
||||
#[derive(Debug)]
|
||||
pub struct State {
|
||||
/// The QUIC inner state machine for this specific connection.
|
||||
connection: quinn_proto::Connection,
|
||||
|
||||
/// State of all the substreams that the muxer reports as open.
|
||||
pub substreams: HashMap<quinn_proto::StreamId, SubstreamState>,
|
||||
|
||||
/// Waker to wake if a new outbound substream is opened.
|
||||
pub poll_outbound_waker: Option<Waker>,
|
||||
/// Waker to wake if a new inbound substream was happened.
|
||||
pub poll_inbound_waker: Option<Waker>,
|
||||
/// Waker to wake if the connection should be polled again.
|
||||
pub poll_connection_waker: Option<Waker>,
|
||||
}
|
||||
|
||||
impl State {
|
||||
fn unchecked_substream_state(&mut self, id: quinn_proto::StreamId) -> &mut SubstreamState {
|
||||
self.substreams
|
||||
.get_mut(&id)
|
||||
.expect("Substream should be known.")
|
||||
Poll::Ready(Ok(()))
|
||||
}
|
||||
}
|
||||
|
@ -20,9 +20,12 @@
|
||||
|
||||
//! Future that drives a QUIC connection until is has performed its TLS handshake.
|
||||
|
||||
use crate::{Connection, Error};
|
||||
use crate::{Connection, ConnectionError, Error};
|
||||
|
||||
use futures::prelude::*;
|
||||
use futures::{
|
||||
future::{select, Either, FutureExt, Select},
|
||||
prelude::*,
|
||||
};
|
||||
use futures_timer::Delay;
|
||||
use libp2p_identity::PeerId;
|
||||
use std::{
|
||||
@ -34,64 +37,46 @@ use std::{
|
||||
/// A QUIC connection currently being negotiated.
|
||||
#[derive(Debug)]
|
||||
pub struct Connecting {
|
||||
connection: Option<Connection>,
|
||||
timeout: Delay,
|
||||
connecting: Select<quinn::Connecting, Delay>,
|
||||
}
|
||||
|
||||
impl Connecting {
|
||||
pub(crate) fn new(connection: Connection, timeout: Duration) -> Self {
|
||||
pub(crate) fn new(connection: quinn::Connecting, timeout: Duration) -> Self {
|
||||
Connecting {
|
||||
connection: Some(connection),
|
||||
timeout: Delay::new(timeout),
|
||||
connecting: select(connection, Delay::new(timeout)),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Connecting {
|
||||
/// Returns the address of the node we're connected to.
|
||||
/// Panics if the connection is still handshaking.
|
||||
fn remote_peer_id(connection: &quinn::Connection) -> PeerId {
|
||||
let identity = connection
|
||||
.peer_identity()
|
||||
.expect("connection got identity because it passed TLS handshake; qed");
|
||||
let certificates: Box<Vec<rustls::Certificate>> =
|
||||
identity.downcast().expect("we rely on rustls feature; qed");
|
||||
let end_entity = certificates
|
||||
.get(0)
|
||||
.expect("there should be exactly one certificate; qed");
|
||||
let p2p_cert = libp2p_tls::certificate::parse(end_entity)
|
||||
.expect("the certificate was validated during TLS handshake; qed");
|
||||
p2p_cert.peer_id()
|
||||
}
|
||||
}
|
||||
|
||||
impl Future for Connecting {
|
||||
type Output = Result<(PeerId, Connection), Error>;
|
||||
|
||||
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
|
||||
let connection = self
|
||||
.connection
|
||||
.as_mut()
|
||||
.expect("Future polled after it has completed");
|
||||
let connection = match futures::ready!(self.connecting.poll_unpin(cx)) {
|
||||
Either::Right(_) => return Poll::Ready(Err(Error::HandshakeTimedOut)),
|
||||
Either::Left((connection, _)) => connection.map_err(ConnectionError)?,
|
||||
};
|
||||
|
||||
loop {
|
||||
let event = match connection.poll_event(cx) {
|
||||
Poll::Ready(Some(event)) => event,
|
||||
Poll::Ready(None) => return Poll::Ready(Err(Error::EndpointDriverCrashed)),
|
||||
Poll::Pending => {
|
||||
return self
|
||||
.timeout
|
||||
.poll_unpin(cx)
|
||||
.map(|()| Err(Error::HandshakeTimedOut));
|
||||
}
|
||||
};
|
||||
match event {
|
||||
quinn_proto::Event::Connected => {
|
||||
// Parse the remote's Id identity from the certificate.
|
||||
let identity = connection
|
||||
.peer_identity()
|
||||
.expect("connection got identity because it passed TLS handshake; qed");
|
||||
let certificates: Box<Vec<rustls::Certificate>> =
|
||||
identity.downcast().expect("we rely on rustls feature; qed");
|
||||
let end_entity = certificates
|
||||
.get(0)
|
||||
.expect("there should be exactly one certificate; qed");
|
||||
let p2p_cert = libp2p_tls::certificate::parse(end_entity)
|
||||
.expect("the certificate was validated during TLS handshake; qed");
|
||||
let peer_id = p2p_cert.peer_id();
|
||||
|
||||
return Poll::Ready(Ok((peer_id, self.connection.take().unwrap())));
|
||||
}
|
||||
quinn_proto::Event::ConnectionLost { reason } => {
|
||||
return Poll::Ready(Err(Error::Connection(reason.into())))
|
||||
}
|
||||
quinn_proto::Event::HandshakeDataReady | quinn_proto::Event::Stream(_) => {}
|
||||
quinn_proto::Event::DatagramReceived => {
|
||||
debug_assert!(false, "Datagrams are not supported")
|
||||
}
|
||||
}
|
||||
}
|
||||
let peer_id = Self::remote_peer_id(&connection);
|
||||
let muxer = Connection::new(connection);
|
||||
Poll::Ready(Ok((peer_id, muxer)))
|
||||
}
|
||||
}
|
||||
|
86
transports/quic/src/connection/stream.rs
Normal file
86
transports/quic/src/connection/stream.rs
Normal file
@ -0,0 +1,86 @@
|
||||
// Copyright 2022 Protocol Labs.
|
||||
//
|
||||
// Permission is hereby granted, free of charge, to any person obtaining a
|
||||
// copy of this software and associated documentation files (the "Software"),
|
||||
// to deal in the Software without restriction, including without limitation
|
||||
// the rights to use, copy, modify, merge, publish, distribute, sublicense,
|
||||
// and/or sell copies of the Software, and to permit persons to whom the
|
||||
// Software is furnished to do so, subject to the following conditions:
|
||||
//
|
||||
// The above copyright notice and this permission notice shall be included in
|
||||
// all copies or substantial portions of the Software.
|
||||
//
|
||||
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
|
||||
// OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
|
||||
// FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
|
||||
// DEALINGS IN THE SOFTWARE.
|
||||
|
||||
use std::{
|
||||
io::{self},
|
||||
pin::Pin,
|
||||
task::{Context, Poll},
|
||||
};
|
||||
|
||||
use futures::{AsyncRead, AsyncWrite};
|
||||
|
||||
/// A single stream on a connection
|
||||
pub struct Stream {
|
||||
/// A send part of the stream
|
||||
send: quinn::SendStream,
|
||||
/// A receive part of the stream
|
||||
recv: quinn::RecvStream,
|
||||
/// Whether the stream is closed or not
|
||||
close_result: Option<Result<(), io::ErrorKind>>,
|
||||
}
|
||||
|
||||
impl Stream {
|
||||
pub(super) fn new(send: quinn::SendStream, recv: quinn::RecvStream) -> Self {
|
||||
Self {
|
||||
send,
|
||||
recv,
|
||||
close_result: None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl AsyncRead for Stream {
|
||||
fn poll_read(
|
||||
mut self: Pin<&mut Self>,
|
||||
cx: &mut Context,
|
||||
buf: &mut [u8],
|
||||
) -> Poll<io::Result<usize>> {
|
||||
if let Some(close_result) = self.close_result {
|
||||
if close_result.is_err() {
|
||||
return Poll::Ready(Ok(0));
|
||||
}
|
||||
}
|
||||
Pin::new(&mut self.recv).poll_read(cx, buf)
|
||||
}
|
||||
}
|
||||
|
||||
impl AsyncWrite for Stream {
|
||||
fn poll_write(
|
||||
mut self: Pin<&mut Self>,
|
||||
cx: &mut Context,
|
||||
buf: &[u8],
|
||||
) -> Poll<io::Result<usize>> {
|
||||
Pin::new(&mut self.send).poll_write(cx, buf)
|
||||
}
|
||||
|
||||
fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<io::Result<()>> {
|
||||
Pin::new(&mut self.send).poll_flush(cx)
|
||||
}
|
||||
|
||||
fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<io::Result<()>> {
|
||||
if let Some(close_result) = self.close_result {
|
||||
// For some reason poll_close needs to be 'fuse'able
|
||||
return Poll::Ready(close_result.map_err(Into::into));
|
||||
}
|
||||
let close_result = futures::ready!(Pin::new(&mut self.send).poll_close(cx));
|
||||
self.close_result = Some(close_result.as_ref().map_err(|e| e.kind()).copied());
|
||||
Poll::Ready(close_result)
|
||||
}
|
||||
}
|
@ -1,257 +0,0 @@
|
||||
// Copyright 2022 Protocol Labs.
|
||||
//
|
||||
// Permission is hereby granted, free of charge, to any person obtaining a
|
||||
// copy of this software and associated documentation files (the "Software"),
|
||||
// to deal in the Software without restriction, including without limitation
|
||||
// the rights to use, copy, modify, merge, publish, distribute, sublicense,
|
||||
// and/or sell copies of the Software, and to permit persons to whom the
|
||||
// Software is furnished to do so, subject to the following conditions:
|
||||
//
|
||||
// The above copyright notice and this permission notice shall be included in
|
||||
// all copies or substantial portions of the Software.
|
||||
//
|
||||
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
|
||||
// OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
|
||||
// FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
|
||||
// DEALINGS IN THE SOFTWARE.
|
||||
|
||||
use std::{
|
||||
io::{self, Write},
|
||||
pin::Pin,
|
||||
sync::Arc,
|
||||
task::{Context, Poll, Waker},
|
||||
};
|
||||
|
||||
use futures::{AsyncRead, AsyncWrite};
|
||||
use parking_lot::Mutex;
|
||||
|
||||
use super::State;
|
||||
|
||||
/// Wakers for the [`AsyncRead`] and [`AsyncWrite`] on a substream.
|
||||
#[derive(Debug, Default, Clone)]
|
||||
pub struct SubstreamState {
|
||||
/// Waker to wake if the substream becomes readable.
|
||||
pub read_waker: Option<Waker>,
|
||||
/// Waker to wake if the substream becomes writable, closed or stopped.
|
||||
pub write_waker: Option<Waker>,
|
||||
/// Waker to wake if the substream becomes closed or stopped.
|
||||
pub close_waker: Option<Waker>,
|
||||
|
||||
pub write_state: WriteState,
|
||||
}
|
||||
|
||||
impl SubstreamState {
|
||||
/// Wake all wakers for reading, writing and closed the stream.
|
||||
pub fn wake_all(&mut self) {
|
||||
if let Some(waker) = self.read_waker.take() {
|
||||
waker.wake();
|
||||
}
|
||||
if let Some(waker) = self.write_waker.take() {
|
||||
waker.wake();
|
||||
}
|
||||
if let Some(waker) = self.close_waker.take() {
|
||||
waker.wake();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// A single stream on a connection
|
||||
#[derive(Debug)]
|
||||
pub struct Substream {
|
||||
/// The id of the stream.
|
||||
id: quinn_proto::StreamId,
|
||||
/// The state of the [`super::Connection`] this stream belongs to.
|
||||
state: Arc<Mutex<State>>,
|
||||
}
|
||||
|
||||
impl Substream {
|
||||
pub fn new(id: quinn_proto::StreamId, state: Arc<Mutex<State>>) -> Self {
|
||||
Self { id, state }
|
||||
}
|
||||
}
|
||||
|
||||
impl AsyncRead for Substream {
|
||||
fn poll_read(
|
||||
self: Pin<&mut Self>,
|
||||
cx: &mut Context<'_>,
|
||||
mut buf: &mut [u8],
|
||||
) -> Poll<io::Result<usize>> {
|
||||
let mut state = self.state.lock();
|
||||
|
||||
let mut stream = state.connection.recv_stream(self.id);
|
||||
let mut chunks = match stream.read(true) {
|
||||
Ok(chunks) => chunks,
|
||||
Err(quinn_proto::ReadableError::UnknownStream) => {
|
||||
return Poll::Ready(Ok(0));
|
||||
}
|
||||
Err(quinn_proto::ReadableError::IllegalOrderedRead) => {
|
||||
unreachable!(
|
||||
"Illegal ordered read can only happen if `stream.read(false)` is used."
|
||||
);
|
||||
}
|
||||
};
|
||||
|
||||
let mut bytes = 0;
|
||||
let mut pending = false;
|
||||
let mut error = None;
|
||||
loop {
|
||||
if buf.is_empty() {
|
||||
// Chunks::next will continue returning `Ok(Some(_))` with an
|
||||
// empty chunk if there is no space left in the buffer, so we
|
||||
// break early here.
|
||||
break;
|
||||
}
|
||||
let chunk = match chunks.next(buf.len()) {
|
||||
Ok(Some(chunk)) => chunk,
|
||||
Ok(None) => break,
|
||||
Err(err @ quinn_proto::ReadError::Reset(_)) => {
|
||||
error = Some(Err(io::Error::new(io::ErrorKind::ConnectionReset, err)));
|
||||
break;
|
||||
}
|
||||
Err(quinn_proto::ReadError::Blocked) => {
|
||||
pending = true;
|
||||
break;
|
||||
}
|
||||
};
|
||||
|
||||
buf.write_all(&chunk.bytes).expect("enough buffer space");
|
||||
bytes += chunk.bytes.len();
|
||||
}
|
||||
if chunks.finalize().should_transmit() {
|
||||
if let Some(waker) = state.poll_connection_waker.take() {
|
||||
waker.wake();
|
||||
}
|
||||
}
|
||||
if let Some(err) = error {
|
||||
return Poll::Ready(err);
|
||||
}
|
||||
|
||||
if pending && bytes == 0 {
|
||||
let substream_state = state.unchecked_substream_state(self.id);
|
||||
substream_state.read_waker = Some(cx.waker().clone());
|
||||
return Poll::Pending;
|
||||
}
|
||||
|
||||
Poll::Ready(Ok(bytes))
|
||||
}
|
||||
}
|
||||
|
||||
impl AsyncWrite for Substream {
|
||||
fn poll_write(
|
||||
self: Pin<&mut Self>,
|
||||
cx: &mut Context<'_>,
|
||||
buf: &[u8],
|
||||
) -> Poll<std::io::Result<usize>> {
|
||||
let mut state = self.state.lock();
|
||||
|
||||
match state.connection.send_stream(self.id).write(buf) {
|
||||
Ok(bytes) => {
|
||||
if let Some(waker) = state.poll_connection_waker.take() {
|
||||
waker.wake();
|
||||
}
|
||||
Poll::Ready(Ok(bytes))
|
||||
}
|
||||
Err(quinn_proto::WriteError::Blocked) => {
|
||||
let substream_state = state.unchecked_substream_state(self.id);
|
||||
substream_state.write_waker = Some(cx.waker().clone());
|
||||
Poll::Pending
|
||||
}
|
||||
Err(err @ quinn_proto::WriteError::Stopped(_)) => {
|
||||
Poll::Ready(Err(io::Error::new(io::ErrorKind::ConnectionReset, err)))
|
||||
}
|
||||
Err(quinn_proto::WriteError::UnknownStream) => {
|
||||
Poll::Ready(Err(io::ErrorKind::BrokenPipe.into()))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
|
||||
// quinn doesn't support flushing, calling close will flush all substreams.
|
||||
Poll::Ready(Ok(()))
|
||||
}
|
||||
|
||||
fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
|
||||
let mut inner = self.state.lock();
|
||||
|
||||
let substream_state = inner.unchecked_substream_state(self.id);
|
||||
match substream_state.write_state {
|
||||
WriteState::Open => {}
|
||||
WriteState::Closing => {
|
||||
substream_state.close_waker = Some(cx.waker().clone());
|
||||
return Poll::Pending;
|
||||
}
|
||||
WriteState::Closed => return Poll::Ready(Ok(())),
|
||||
WriteState::Stopped => {
|
||||
let err = quinn_proto::FinishError::Stopped(0u32.into());
|
||||
return Poll::Ready(Err(io::Error::new(io::ErrorKind::ConnectionReset, err)));
|
||||
}
|
||||
}
|
||||
|
||||
match inner.connection.send_stream(self.id).finish() {
|
||||
Ok(()) => {
|
||||
let substream_state = inner.unchecked_substream_state(self.id);
|
||||
substream_state.close_waker = Some(cx.waker().clone());
|
||||
substream_state.write_state = WriteState::Closing;
|
||||
Poll::Pending
|
||||
}
|
||||
Err(err @ quinn_proto::FinishError::Stopped(_)) => {
|
||||
Poll::Ready(Err(io::Error::new(io::ErrorKind::ConnectionReset, err)))
|
||||
}
|
||||
Err(quinn_proto::FinishError::UnknownStream) => {
|
||||
// We never make up IDs so the stream must have existed at some point if we get to here.
|
||||
// `UnknownStream` is also emitted in case the stream is already finished, hence just
|
||||
// return `Ok(())` here.
|
||||
Poll::Ready(Ok(()))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Drop for Substream {
|
||||
fn drop(&mut self) {
|
||||
let mut state = self.state.lock();
|
||||
state.substreams.remove(&self.id);
|
||||
// Send `STOP_STREAM` if the remote did not finish the stream yet.
|
||||
// We have to manually check the read stream since we might have
|
||||
// received a `FIN` (without any other stream data) after the last
|
||||
// time we tried to read.
|
||||
let mut is_read_done = false;
|
||||
if let Ok(mut chunks) = state.connection.recv_stream(self.id).read(true) {
|
||||
if let Ok(chunk) = chunks.next(0) {
|
||||
is_read_done = chunk.is_none();
|
||||
}
|
||||
let _ = chunks.finalize();
|
||||
}
|
||||
if !is_read_done {
|
||||
let _ = state.connection.recv_stream(self.id).stop(0u32.into());
|
||||
}
|
||||
// Close the writing side.
|
||||
let mut send_stream = state.connection.send_stream(self.id);
|
||||
match send_stream.finish() {
|
||||
Ok(()) => {}
|
||||
// Already finished or reset, which is fine.
|
||||
Err(quinn_proto::FinishError::UnknownStream) => {}
|
||||
Err(quinn_proto::FinishError::Stopped(reason)) => {
|
||||
let _ = send_stream.reset(reason);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Default, Clone)]
|
||||
pub enum WriteState {
|
||||
/// The stream is open for writing.
|
||||
#[default]
|
||||
Open,
|
||||
/// The writing side of the stream is closing.
|
||||
Closing,
|
||||
/// All data was successfully sent to the remote and the stream closed,
|
||||
/// i.e. a [`quinn_proto::StreamEvent::Finished`] was reported for it.
|
||||
Closed,
|
||||
/// The stream was stopped by the remote before all data could be
|
||||
/// sent.
|
||||
Stopped,
|
||||
}
|
@ -1,674 +0,0 @@
|
||||
// Copyright 2017-2020 Parity Technologies (UK) Ltd.
|
||||
//
|
||||
// Permission is hereby granted, free of charge, to any person obtaining a
|
||||
// copy of this software and associated documentation files (the "Software"),
|
||||
// to deal in the Software without restriction, including without limitation
|
||||
// the rights to use, copy, modify, merge, publish, distribute, sublicense,
|
||||
// and/or sell copies of the Software, and to permit persons to whom the
|
||||
// Software is furnished to do so, subject to the following conditions:
|
||||
//
|
||||
// The above copyright notice and this permission notice shall be included in
|
||||
// all copies or substantial portions of the Software.
|
||||
//
|
||||
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
|
||||
// OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
|
||||
// FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
|
||||
// DEALINGS IN THE SOFTWARE.
|
||||
|
||||
use crate::{
|
||||
provider::Provider,
|
||||
transport::{ProtocolVersion, SocketFamily},
|
||||
ConnectError, Connection, Error,
|
||||
};
|
||||
|
||||
use bytes::BytesMut;
|
||||
use futures::{
|
||||
channel::{mpsc, oneshot},
|
||||
prelude::*,
|
||||
};
|
||||
use quinn_proto::VarInt;
|
||||
use std::{
|
||||
collections::HashMap,
|
||||
net::{Ipv4Addr, Ipv6Addr, SocketAddr},
|
||||
ops::ControlFlow,
|
||||
pin::Pin,
|
||||
sync::Arc,
|
||||
task::{Context, Poll},
|
||||
time::{Duration, Instant},
|
||||
};
|
||||
|
||||
// The `Driver` drops packets if the channel to the connection
|
||||
// or transport is full.
|
||||
// Set capacity 10 to avoid unnecessary packet drops if the receiver
|
||||
// is only very briefly busy, but not buffer a large amount of packets
|
||||
// if it is blocked longer.
|
||||
const CHANNEL_CAPACITY: usize = 10;
|
||||
|
||||
/// Config for the transport.
|
||||
#[derive(Clone)]
|
||||
pub struct Config {
|
||||
/// Timeout for the initial handshake when establishing a connection.
|
||||
/// The actual timeout is the minimum of this an the [`Config::max_idle_timeout`].
|
||||
pub handshake_timeout: Duration,
|
||||
/// Maximum duration of inactivity in ms to accept before timing out the connection.
|
||||
pub max_idle_timeout: u32,
|
||||
/// Period of inactivity before sending a keep-alive packet.
|
||||
/// Must be set lower than the idle_timeout of both
|
||||
/// peers to be effective.
|
||||
///
|
||||
/// See [`quinn_proto::TransportConfig::keep_alive_interval`] for more
|
||||
/// info.
|
||||
pub keep_alive_interval: Duration,
|
||||
/// Maximum number of incoming bidirectional streams that may be open
|
||||
/// concurrently by the remote peer.
|
||||
pub max_concurrent_stream_limit: u32,
|
||||
|
||||
/// Max unacknowledged data in bytes that may be send on a single stream.
|
||||
pub max_stream_data: u32,
|
||||
|
||||
/// Max unacknowledged data in bytes that may be send in total on all streams
|
||||
/// of a connection.
|
||||
pub max_connection_data: u32,
|
||||
|
||||
/// Support QUIC version draft-29 for dialing and listening.
|
||||
///
|
||||
/// Per default only QUIC Version 1 / [`libp2p_core::multiaddr::Protocol::QuicV1`]
|
||||
/// is supported.
|
||||
///
|
||||
/// If support for draft-29 is enabled servers support draft-29 and version 1 on all
|
||||
/// QUIC listening addresses.
|
||||
/// As client the version is chosen based on the remote's address.
|
||||
pub support_draft_29: bool,
|
||||
|
||||
/// TLS client config for the inner [`quinn_proto::ClientConfig`].
|
||||
client_tls_config: Arc<rustls::ClientConfig>,
|
||||
/// TLS server config for the inner [`quinn_proto::ServerConfig`].
|
||||
server_tls_config: Arc<rustls::ServerConfig>,
|
||||
}
|
||||
|
||||
impl Config {
|
||||
/// Creates a new configuration object with default values.
|
||||
pub fn new(keypair: &libp2p_identity::Keypair) -> Self {
|
||||
let client_tls_config = Arc::new(libp2p_tls::make_client_config(keypair, None).unwrap());
|
||||
let server_tls_config = Arc::new(libp2p_tls::make_server_config(keypair).unwrap());
|
||||
Self {
|
||||
client_tls_config,
|
||||
server_tls_config,
|
||||
support_draft_29: false,
|
||||
handshake_timeout: Duration::from_secs(5),
|
||||
max_idle_timeout: 30 * 1000,
|
||||
max_concurrent_stream_limit: 256,
|
||||
keep_alive_interval: Duration::from_secs(15),
|
||||
max_connection_data: 15_000_000,
|
||||
|
||||
// Ensure that one stream is not consuming the whole connection.
|
||||
max_stream_data: 10_000_000,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Represents the inner configuration for [`quinn_proto`].
|
||||
#[derive(Debug, Clone)]
|
||||
pub(crate) struct QuinnConfig {
|
||||
client_config: quinn_proto::ClientConfig,
|
||||
server_config: Arc<quinn_proto::ServerConfig>,
|
||||
endpoint_config: Arc<quinn_proto::EndpointConfig>,
|
||||
}
|
||||
|
||||
impl From<Config> for QuinnConfig {
|
||||
fn from(config: Config) -> QuinnConfig {
|
||||
let Config {
|
||||
client_tls_config,
|
||||
server_tls_config,
|
||||
max_idle_timeout,
|
||||
max_concurrent_stream_limit,
|
||||
keep_alive_interval,
|
||||
max_connection_data,
|
||||
max_stream_data,
|
||||
support_draft_29,
|
||||
handshake_timeout: _,
|
||||
} = config;
|
||||
let mut transport = quinn_proto::TransportConfig::default();
|
||||
// Disable uni-directional streams.
|
||||
transport.max_concurrent_uni_streams(0u32.into());
|
||||
transport.max_concurrent_bidi_streams(max_concurrent_stream_limit.into());
|
||||
// Disable datagrams.
|
||||
transport.datagram_receive_buffer_size(None);
|
||||
transport.keep_alive_interval(Some(keep_alive_interval));
|
||||
transport.max_idle_timeout(Some(VarInt::from_u32(max_idle_timeout).into()));
|
||||
transport.allow_spin(false);
|
||||
transport.stream_receive_window(max_stream_data.into());
|
||||
transport.receive_window(max_connection_data.into());
|
||||
let transport = Arc::new(transport);
|
||||
|
||||
let mut server_config = quinn_proto::ServerConfig::with_crypto(server_tls_config);
|
||||
server_config.transport = Arc::clone(&transport);
|
||||
// Disables connection migration.
|
||||
// Long-term this should be enabled, however we then need to handle address change
|
||||
// on connections in the `Connection`.
|
||||
server_config.migration(false);
|
||||
|
||||
let mut client_config = quinn_proto::ClientConfig::new(client_tls_config);
|
||||
client_config.transport_config(transport);
|
||||
|
||||
let mut endpoint_config = quinn_proto::EndpointConfig::default();
|
||||
if !support_draft_29 {
|
||||
endpoint_config.supported_versions(vec![1]);
|
||||
}
|
||||
|
||||
QuinnConfig {
|
||||
client_config,
|
||||
server_config: Arc::new(server_config),
|
||||
endpoint_config: Arc::new(endpoint_config),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Channel used to send commands to the [`Driver`].
|
||||
#[derive(Debug, Clone)]
|
||||
pub(crate) struct Channel {
|
||||
/// Channel to the background of the endpoint.
|
||||
to_endpoint: mpsc::Sender<ToEndpoint>,
|
||||
/// Address that the socket is bound to.
|
||||
/// Note: this may be a wildcard ip address.
|
||||
socket_addr: SocketAddr,
|
||||
}
|
||||
|
||||
impl Channel {
|
||||
/// Builds a new endpoint that is listening on the [`SocketAddr`].
|
||||
pub(crate) fn new_bidirectional<P: Provider>(
|
||||
quinn_config: QuinnConfig,
|
||||
socket_addr: SocketAddr,
|
||||
) -> Result<(Self, mpsc::Receiver<Connection>), Error> {
|
||||
// Channel for forwarding new inbound connections to the listener.
|
||||
let (new_connections_tx, new_connections_rx) = mpsc::channel(CHANNEL_CAPACITY);
|
||||
let endpoint = Self::new::<P>(quinn_config, socket_addr, Some(new_connections_tx))?;
|
||||
Ok((endpoint, new_connections_rx))
|
||||
}
|
||||
|
||||
/// Builds a new endpoint that only supports outbound connections.
|
||||
pub(crate) fn new_dialer<P: Provider>(
|
||||
quinn_config: QuinnConfig,
|
||||
socket_family: SocketFamily,
|
||||
) -> Result<Self, Error> {
|
||||
let socket_addr = match socket_family {
|
||||
SocketFamily::Ipv4 => SocketAddr::new(Ipv4Addr::UNSPECIFIED.into(), 0),
|
||||
SocketFamily::Ipv6 => SocketAddr::new(Ipv6Addr::UNSPECIFIED.into(), 0),
|
||||
};
|
||||
Self::new::<P>(quinn_config, socket_addr, None)
|
||||
}
|
||||
|
||||
/// Spawn a new [`Driver`] that runs in the background.
|
||||
fn new<P: Provider>(
|
||||
quinn_config: QuinnConfig,
|
||||
socket_addr: SocketAddr,
|
||||
new_connections: Option<mpsc::Sender<Connection>>,
|
||||
) -> Result<Self, Error> {
|
||||
let socket = std::net::UdpSocket::bind(socket_addr)?;
|
||||
// NOT blocking, as per man:bind(2), as we pass an IP address.
|
||||
socket.set_nonblocking(true)?;
|
||||
// Capacity 0 to back-pressure the rest of the application if
|
||||
// the udp socket is busy.
|
||||
let (to_endpoint_tx, to_endpoint_rx) = mpsc::channel(0);
|
||||
|
||||
let channel = Self {
|
||||
to_endpoint: to_endpoint_tx,
|
||||
socket_addr: socket.local_addr()?,
|
||||
};
|
||||
|
||||
let server_config = new_connections
|
||||
.is_some()
|
||||
.then_some(quinn_config.server_config);
|
||||
|
||||
let provider_socket = P::from_socket(socket)?;
|
||||
|
||||
let driver = Driver::<P>::new(
|
||||
quinn_config.endpoint_config,
|
||||
quinn_config.client_config,
|
||||
new_connections,
|
||||
server_config,
|
||||
channel.clone(),
|
||||
provider_socket,
|
||||
to_endpoint_rx,
|
||||
);
|
||||
|
||||
// Drive the endpoint future in the background.
|
||||
P::spawn(driver);
|
||||
|
||||
Ok(channel)
|
||||
}
|
||||
|
||||
pub(crate) fn socket_addr(&self) -> &SocketAddr {
|
||||
&self.socket_addr
|
||||
}
|
||||
|
||||
/// Try to send a message to the background task without blocking.
|
||||
///
|
||||
/// This first polls the channel for capacity.
|
||||
/// If the channel is full, the message is returned in `Ok(Err(_))`
|
||||
/// and the context's waker is registered for wake-up.
|
||||
///
|
||||
/// If the background task crashed `Err` is returned.
|
||||
pub(crate) fn try_send(
|
||||
&mut self,
|
||||
to_endpoint: ToEndpoint,
|
||||
cx: &mut Context<'_>,
|
||||
) -> Result<Result<(), ToEndpoint>, Disconnected> {
|
||||
match self.to_endpoint.poll_ready_unpin(cx) {
|
||||
Poll::Ready(Ok(())) => {}
|
||||
Poll::Ready(Err(e)) => {
|
||||
debug_assert!(
|
||||
e.is_disconnected(),
|
||||
"mpsc::Sender can only be disconnected when calling `poll_ready_unpin"
|
||||
);
|
||||
|
||||
return Err(Disconnected {});
|
||||
}
|
||||
Poll::Pending => return Ok(Err(to_endpoint)),
|
||||
};
|
||||
|
||||
if let Err(e) = self.to_endpoint.start_send(to_endpoint) {
|
||||
debug_assert!(e.is_disconnected(), "We called `Sink::poll_ready` so we are guaranteed to have a slot. If this fails, it means we are disconnected.");
|
||||
|
||||
return Err(Disconnected {});
|
||||
}
|
||||
|
||||
Ok(Ok(()))
|
||||
}
|
||||
|
||||
pub(crate) async fn send(&mut self, to_endpoint: ToEndpoint) -> Result<(), Disconnected> {
|
||||
self.to_endpoint
|
||||
.send(to_endpoint)
|
||||
.await
|
||||
.map_err(|_| Disconnected {})
|
||||
}
|
||||
|
||||
/// Send a message to inform the [`Driver`] about an
|
||||
/// event caused by the owner of this [`Channel`] dropping.
|
||||
/// This clones the sender to the endpoint to guarantee delivery.
|
||||
/// This should *not* be called for regular messages.
|
||||
pub(crate) fn send_on_drop(&mut self, to_endpoint: ToEndpoint) {
|
||||
let _ = self.to_endpoint.clone().try_send(to_endpoint);
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, thiserror::Error, Clone, PartialEq, Eq)]
|
||||
#[error("Background task disconnected")]
|
||||
pub(crate) struct Disconnected {}
|
||||
/// Message sent to the endpoint background task.
|
||||
#[derive(Debug)]
|
||||
pub(crate) enum ToEndpoint {
|
||||
/// Instruct the [`quinn_proto::Endpoint`] to start connecting to the given address.
|
||||
Dial {
|
||||
/// UDP address to connect to.
|
||||
addr: SocketAddr,
|
||||
/// Version to dial the remote on.
|
||||
version: ProtocolVersion,
|
||||
/// Channel to return the result of the dialing to.
|
||||
result: oneshot::Sender<Result<Connection, Error>>,
|
||||
},
|
||||
/// Send by a [`quinn_proto::Connection`] when the endpoint needs to process an event generated
|
||||
/// by a connection. The event itself is opaque to us. Only `quinn_proto` knows what is in
|
||||
/// there.
|
||||
ProcessConnectionEvent {
|
||||
connection_id: quinn_proto::ConnectionHandle,
|
||||
event: quinn_proto::EndpointEvent,
|
||||
},
|
||||
/// Instruct the endpoint to send a packet of data on its UDP socket.
|
||||
SendUdpPacket(quinn_proto::Transmit),
|
||||
/// The [`GenTransport`][crate::GenTransport] dialer or listener coupled to this endpoint
|
||||
/// was dropped.
|
||||
/// Once all pending connections are closed, the [`Driver`] should shut down.
|
||||
Decoupled,
|
||||
}
|
||||
|
||||
/// Driver that runs in the background for as long as the endpoint is alive. Responsible for
|
||||
/// processing messages and the UDP socket.
|
||||
///
|
||||
/// # Behaviour
|
||||
///
|
||||
/// This background task is responsible for the following:
|
||||
///
|
||||
/// - Sending packets on the UDP socket.
|
||||
/// - Receiving packets from the UDP socket and feed them to the [`quinn_proto::Endpoint`] state
|
||||
/// machine.
|
||||
/// - Transmitting events generated by the [`quinn_proto::Endpoint`] to the corresponding
|
||||
/// [`crate::Connection`].
|
||||
/// - Receiving messages from the `rx` and processing the requested actions. This includes
|
||||
/// UDP packets to send and events emitted by the [`crate::Connection`] objects.
|
||||
/// - Sending new connections on `new_connection_tx`.
|
||||
///
|
||||
/// When it comes to channels, there exists three main multi-producer-single-consumer channels
|
||||
/// in play:
|
||||
///
|
||||
/// - One channel, represented by `EndpointChannel::to_endpoint` and `Driver::rx`,
|
||||
/// that communicates messages from [`Channel`] to the [`Driver`].
|
||||
/// - One channel for each existing connection that communicates messages from the
|
||||
/// [`Driver` to that [`crate::Connection`].
|
||||
/// - One channel for the [`Driver`] to send newly-opened connections to. The receiving
|
||||
/// side is processed by the [`GenTransport`][crate::GenTransport].
|
||||
///
|
||||
///
|
||||
/// ## Back-pressure
|
||||
///
|
||||
/// ### If writing to the UDP socket is blocked
|
||||
///
|
||||
/// In order to avoid an unbounded buffering of events, we prioritize sending data on the UDP
|
||||
/// socket over everything else. Messages from the rest of the application sent through the
|
||||
/// [`Channel`] are only processed if the UDP socket is ready so that we propagate back-pressure
|
||||
/// in case of a busy socket. For connections, thus this eventually also back-pressures the
|
||||
/// `AsyncWrite`on substreams.
|
||||
///
|
||||
///
|
||||
/// ### Back-pressuring the remote if the application is busy
|
||||
///
|
||||
/// If the channel to a connection is full because the connection is busy, inbound datagrams
|
||||
/// for that connection are dropped so that the remote is backpressured.
|
||||
/// The same applies for new connections if the transport is too busy to received it.
|
||||
///
|
||||
///
|
||||
/// # Shutdown
|
||||
///
|
||||
/// The background task shuts down if an [`ToEndpoint::Decoupled`] event was received and the
|
||||
/// last active connection has drained.
|
||||
#[derive(Debug)]
|
||||
pub(crate) struct Driver<P: Provider> {
|
||||
// The actual QUIC state machine.
|
||||
endpoint: quinn_proto::Endpoint,
|
||||
// QuinnConfig for client connections.
|
||||
client_config: quinn_proto::ClientConfig,
|
||||
// Copy of the channel to the endpoint driver that is passed to each new connection.
|
||||
channel: Channel,
|
||||
// Channel to receive messages from the transport or connections.
|
||||
rx: mpsc::Receiver<ToEndpoint>,
|
||||
|
||||
// Socket for sending and receiving datagrams.
|
||||
provider_socket: P,
|
||||
// Future for writing the next packet to the socket.
|
||||
next_packet_out: Option<quinn_proto::Transmit>,
|
||||
|
||||
// List of all active connections, with a sender to notify them of events.
|
||||
alive_connections:
|
||||
HashMap<quinn_proto::ConnectionHandle, mpsc::Sender<quinn_proto::ConnectionEvent>>,
|
||||
// Channel to forward new inbound connections to the transport.
|
||||
// `None` if server capabilities are disabled, i.e. the endpoint is only used for dialing.
|
||||
new_connection_tx: Option<mpsc::Sender<Connection>>,
|
||||
// Whether the transport dropped its handle for this endpoint.
|
||||
is_decoupled: bool,
|
||||
}
|
||||
|
||||
impl<P: Provider> Driver<P> {
|
||||
fn new(
|
||||
endpoint_config: Arc<quinn_proto::EndpointConfig>,
|
||||
client_config: quinn_proto::ClientConfig,
|
||||
new_connection_tx: Option<mpsc::Sender<Connection>>,
|
||||
server_config: Option<Arc<quinn_proto::ServerConfig>>,
|
||||
channel: Channel,
|
||||
socket: P,
|
||||
rx: mpsc::Receiver<ToEndpoint>,
|
||||
) -> Self {
|
||||
Driver {
|
||||
endpoint: quinn_proto::Endpoint::new(endpoint_config, server_config, false),
|
||||
client_config,
|
||||
channel,
|
||||
rx,
|
||||
provider_socket: socket,
|
||||
next_packet_out: None,
|
||||
alive_connections: HashMap::new(),
|
||||
new_connection_tx,
|
||||
is_decoupled: false,
|
||||
}
|
||||
}
|
||||
|
||||
/// Handle a message sent from either the [`GenTransport`](super::GenTransport)
|
||||
/// or a [`crate::Connection`].
|
||||
fn handle_message(
|
||||
&mut self,
|
||||
to_endpoint: ToEndpoint,
|
||||
) -> ControlFlow<(), Option<quinn_proto::Transmit>> {
|
||||
match to_endpoint {
|
||||
ToEndpoint::Dial {
|
||||
addr,
|
||||
result,
|
||||
version,
|
||||
} => {
|
||||
let mut config = self.client_config.clone();
|
||||
if version == ProtocolVersion::Draft29 {
|
||||
config.version(0xff00_001d);
|
||||
}
|
||||
// This `"l"` seems necessary because an empty string is an invalid domain
|
||||
// name. While we don't use domain names, the underlying rustls library
|
||||
// is based upon the assumption that we do.
|
||||
let (connection_id, connection) = match self.endpoint.connect(config, addr, "l") {
|
||||
Ok(c) => c,
|
||||
Err(err) => {
|
||||
let _ = result.send(Err(ConnectError::from(err).into()));
|
||||
return ControlFlow::Continue(None);
|
||||
}
|
||||
};
|
||||
|
||||
debug_assert_eq!(connection.side(), quinn_proto::Side::Client);
|
||||
let (tx, rx) = mpsc::channel(CHANNEL_CAPACITY);
|
||||
let connection = Connection::from_quinn_connection(
|
||||
self.channel.clone(),
|
||||
connection,
|
||||
connection_id,
|
||||
rx,
|
||||
);
|
||||
self.alive_connections.insert(connection_id, tx);
|
||||
let _ = result.send(Ok(connection));
|
||||
}
|
||||
|
||||
// A connection wants to notify the endpoint of something.
|
||||
ToEndpoint::ProcessConnectionEvent {
|
||||
connection_id,
|
||||
event,
|
||||
} => {
|
||||
let has_key = self.alive_connections.contains_key(&connection_id);
|
||||
if !has_key {
|
||||
return ControlFlow::Continue(None);
|
||||
}
|
||||
// We "drained" event indicates that the connection no longer exists and
|
||||
// its ID can be reclaimed.
|
||||
let is_drained_event = event.is_drained();
|
||||
if is_drained_event {
|
||||
self.alive_connections.remove(&connection_id);
|
||||
if self.is_decoupled && self.alive_connections.is_empty() {
|
||||
log::debug!(
|
||||
"Driver is decoupled and no active connections remain. Shutting down."
|
||||
);
|
||||
return ControlFlow::Break(());
|
||||
}
|
||||
}
|
||||
|
||||
let event_back = self.endpoint.handle_event(connection_id, event);
|
||||
|
||||
if let Some(event_back) = event_back {
|
||||
debug_assert!(!is_drained_event);
|
||||
if let Some(sender) = self.alive_connections.get_mut(&connection_id) {
|
||||
// We clone the sender to guarantee that there will be at least one
|
||||
// free slot to send the event.
|
||||
// The channel can not grow out of bound because an `event_back` is
|
||||
// only sent if we previously received an event from the same connection.
|
||||
// If the connection is busy, it won't sent us any more events to handle.
|
||||
let _ = sender.clone().start_send(event_back);
|
||||
} else {
|
||||
log::error!("State mismatch: event for closed connection");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Data needs to be sent on the UDP socket.
|
||||
ToEndpoint::SendUdpPacket(transmit) => return ControlFlow::Continue(Some(transmit)),
|
||||
ToEndpoint::Decoupled => self.handle_decoupling()?,
|
||||
}
|
||||
ControlFlow::Continue(None)
|
||||
}
|
||||
|
||||
/// Handle an UDP datagram received on the socket.
|
||||
/// The datagram content was written into the `socket_recv_buffer`.
|
||||
fn handle_datagram(&mut self, packet: BytesMut, packet_src: SocketAddr) -> ControlFlow<()> {
|
||||
let local_ip = self.channel.socket_addr.ip();
|
||||
// TODO: ECN bits aren't handled
|
||||
let (connec_id, event) =
|
||||
match self
|
||||
.endpoint
|
||||
.handle(Instant::now(), packet_src, Some(local_ip), None, packet)
|
||||
{
|
||||
Some(event) => event,
|
||||
None => return ControlFlow::Continue(()),
|
||||
};
|
||||
match event {
|
||||
quinn_proto::DatagramEvent::ConnectionEvent(event) => {
|
||||
// `event` has type `quinn_proto::ConnectionEvent`, which has multiple
|
||||
// variants. `quinn_proto::Endpoint::handle` however only ever returns
|
||||
// `ConnectionEvent::Datagram`.
|
||||
debug_assert!(format!("{event:?}").contains("Datagram"));
|
||||
|
||||
// Redirect the datagram to its connection.
|
||||
if let Some(sender) = self.alive_connections.get_mut(&connec_id) {
|
||||
match sender.try_send(event) {
|
||||
Ok(()) => {}
|
||||
Err(err) if err.is_disconnected() => {
|
||||
// Connection was dropped by the user.
|
||||
// Inform the endpoint that this connection is drained.
|
||||
self.endpoint
|
||||
.handle_event(connec_id, quinn_proto::EndpointEvent::drained());
|
||||
self.alive_connections.remove(&connec_id);
|
||||
}
|
||||
Err(err) if err.is_full() => {
|
||||
// Connection is too busy. Drop the datagram to back-pressure the remote.
|
||||
log::debug!(
|
||||
"Dropping packet for connection {:?} because the connection's channel is full.",
|
||||
connec_id
|
||||
);
|
||||
}
|
||||
Err(_) => unreachable!("Error is either `Full` or `Disconnected`."),
|
||||
}
|
||||
} else {
|
||||
log::error!("State mismatch: event for closed connection");
|
||||
}
|
||||
}
|
||||
quinn_proto::DatagramEvent::NewConnection(connec) => {
|
||||
// A new connection has been received. `connec_id` is a newly-allocated
|
||||
// identifier.
|
||||
debug_assert_eq!(connec.side(), quinn_proto::Side::Server);
|
||||
let connection_tx = match self.new_connection_tx.as_mut() {
|
||||
Some(tx) => tx,
|
||||
None => {
|
||||
debug_assert!(false, "Endpoint reported a new connection even though server capabilities are disabled.");
|
||||
return ControlFlow::Continue(());
|
||||
}
|
||||
};
|
||||
|
||||
let (tx, rx) = mpsc::channel(CHANNEL_CAPACITY);
|
||||
let connection =
|
||||
Connection::from_quinn_connection(self.channel.clone(), connec, connec_id, rx);
|
||||
match connection_tx.start_send(connection) {
|
||||
Ok(()) => {
|
||||
self.alive_connections.insert(connec_id, tx);
|
||||
}
|
||||
Err(e) if e.is_disconnected() => self.handle_decoupling()?,
|
||||
Err(e) if e.is_full() => log::warn!(
|
||||
"Dropping new incoming connection {:?} because the channel to the listener is full",
|
||||
connec_id
|
||||
),
|
||||
Err(_) => unreachable!("Error is either `Full` or `Disconnected`."),
|
||||
}
|
||||
}
|
||||
}
|
||||
ControlFlow::Continue(())
|
||||
}
|
||||
|
||||
/// The transport dropped the channel to this [`Driver`].
|
||||
fn handle_decoupling(&mut self) -> ControlFlow<()> {
|
||||
if self.alive_connections.is_empty() {
|
||||
return ControlFlow::Break(());
|
||||
}
|
||||
// Listener was closed.
|
||||
self.endpoint.reject_new_connections();
|
||||
self.new_connection_tx = None;
|
||||
self.is_decoupled = true;
|
||||
ControlFlow::Continue(())
|
||||
}
|
||||
}
|
||||
|
||||
/// Future that runs until the [`Driver`] is decoupled and not active connections
|
||||
/// remain
|
||||
impl<P: Provider> Future for Driver<P> {
|
||||
type Output = ();
|
||||
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
|
||||
loop {
|
||||
// Flush any pending pocket so that the socket is reading to write an next
|
||||
// packet.
|
||||
match self.provider_socket.poll_send_flush(cx) {
|
||||
// The pending packet was send or no packet was pending.
|
||||
Poll::Ready(Ok(_)) => {
|
||||
// Start sending a packet on the socket.
|
||||
if let Some(transmit) = self.next_packet_out.take() {
|
||||
self.provider_socket
|
||||
.start_send(transmit.contents.into(), transmit.destination);
|
||||
continue;
|
||||
}
|
||||
|
||||
// The endpoint might request packets to be sent out. This is handled in
|
||||
// priority to avoid buffering up packets.
|
||||
if let Some(transmit) = self.endpoint.poll_transmit() {
|
||||
self.next_packet_out = Some(transmit);
|
||||
continue;
|
||||
}
|
||||
|
||||
// Handle messages from transport and connections.
|
||||
match self.rx.poll_next_unpin(cx) {
|
||||
Poll::Ready(Some(to_endpoint)) => match self.handle_message(to_endpoint) {
|
||||
ControlFlow::Continue(Some(transmit)) => {
|
||||
self.next_packet_out = Some(transmit);
|
||||
continue;
|
||||
}
|
||||
ControlFlow::Continue(None) => continue,
|
||||
ControlFlow::Break(()) => break,
|
||||
},
|
||||
Poll::Ready(None) => {
|
||||
unreachable!("Sender side is never dropped or closed.")
|
||||
}
|
||||
Poll::Pending => {}
|
||||
}
|
||||
}
|
||||
// Errors on the socket are expected to never happen, and we handle them by simply
|
||||
// printing a log message. The packet gets discarded in case of error, but we are
|
||||
// robust to packet losses and it is consequently not a logic error to proceed with
|
||||
// normal operations.
|
||||
Poll::Ready(Err(err)) => {
|
||||
log::warn!("Error while sending on QUIC UDP socket: {:?}", err);
|
||||
continue;
|
||||
}
|
||||
Poll::Pending => {}
|
||||
}
|
||||
|
||||
// Poll for new packets from the remote.
|
||||
match self.provider_socket.poll_recv_from(cx) {
|
||||
Poll::Ready(Ok((bytes, packet_src))) => {
|
||||
let bytes_mut = bytes.as_slice().into();
|
||||
match self.handle_datagram(bytes_mut, packet_src) {
|
||||
ControlFlow::Continue(()) => continue,
|
||||
ControlFlow::Break(()) => break,
|
||||
}
|
||||
}
|
||||
// Errors on the socket are expected to never happen, and we handle them by
|
||||
// simply printing a log message.
|
||||
Poll::Ready(Err(err)) => {
|
||||
log::warn!("Error while receive on QUIC UDP socket: {:?}", err);
|
||||
continue;
|
||||
}
|
||||
Poll::Pending => {}
|
||||
}
|
||||
|
||||
return Poll::Pending;
|
||||
}
|
||||
|
||||
Poll::Ready(())
|
||||
}
|
||||
}
|
@ -1,19 +1,20 @@
|
||||
use std::{net::SocketAddr, time::Duration};
|
||||
use crate::{provider::Provider, Error};
|
||||
|
||||
use futures::future::Either;
|
||||
|
||||
use rand::{distributions, Rng};
|
||||
|
||||
use crate::{
|
||||
endpoint::{self, ToEndpoint},
|
||||
Error, Provider,
|
||||
use std::{
|
||||
net::{SocketAddr, UdpSocket},
|
||||
time::Duration,
|
||||
};
|
||||
|
||||
pub(crate) async fn hole_puncher<P: Provider>(
|
||||
endpoint_channel: endpoint::Channel,
|
||||
socket: UdpSocket,
|
||||
remote_addr: SocketAddr,
|
||||
timeout_duration: Duration,
|
||||
) -> Error {
|
||||
let punch_holes_future = punch_holes::<P>(endpoint_channel, remote_addr);
|
||||
let punch_holes_future = punch_holes::<P>(socket, remote_addr);
|
||||
futures::pin_mut!(punch_holes_future);
|
||||
match futures::future::select(P::sleep(timeout_duration), punch_holes_future).await {
|
||||
Either::Left(_) => Error::HandshakeTimedOut,
|
||||
@ -21,27 +22,18 @@ pub(crate) async fn hole_puncher<P: Provider>(
|
||||
}
|
||||
}
|
||||
|
||||
async fn punch_holes<P: Provider>(
|
||||
mut endpoint_channel: endpoint::Channel,
|
||||
remote_addr: SocketAddr,
|
||||
) -> Error {
|
||||
async fn punch_holes<P: Provider>(socket: UdpSocket, remote_addr: SocketAddr) -> Error {
|
||||
loop {
|
||||
let sleep_duration = Duration::from_millis(rand::thread_rng().gen_range(10..=200));
|
||||
P::sleep(sleep_duration).await;
|
||||
|
||||
let random_udp_packet = ToEndpoint::SendUdpPacket(quinn_proto::Transmit {
|
||||
destination: remote_addr,
|
||||
ecn: None,
|
||||
contents: rand::thread_rng()
|
||||
.sample_iter(distributions::Standard)
|
||||
.take(64)
|
||||
.collect(),
|
||||
segment_size: None,
|
||||
src_ip: None,
|
||||
});
|
||||
let contents: Vec<u8> = rand::thread_rng()
|
||||
.sample_iter(distributions::Standard)
|
||||
.take(64)
|
||||
.collect();
|
||||
|
||||
if endpoint_channel.send(random_udp_packet).await.is_err() {
|
||||
return Error::EndpointDriverCrashed;
|
||||
if let Err(e) = P::send_to(&socket, &contents, remote_addr).await {
|
||||
return Error::Io(e);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -57,16 +57,17 @@
|
||||
|
||||
#![cfg_attr(docsrs, feature(doc_cfg, doc_auto_cfg))]
|
||||
|
||||
mod config;
|
||||
mod connection;
|
||||
mod endpoint;
|
||||
mod hole_punching;
|
||||
mod provider;
|
||||
mod transport;
|
||||
|
||||
use std::net::SocketAddr;
|
||||
|
||||
pub use connection::{Connecting, Connection, Substream};
|
||||
pub use endpoint::Config;
|
||||
pub use config::Config;
|
||||
pub use connection::{Connecting, Connection, Stream};
|
||||
|
||||
#[cfg(feature = "async-std")]
|
||||
pub use provider::async_std;
|
||||
#[cfg(feature = "tokio")]
|
||||
@ -89,8 +90,7 @@ pub enum Error {
|
||||
#[error(transparent)]
|
||||
Io(#[from] std::io::Error),
|
||||
|
||||
/// The task spawned in [`Provider::spawn`] to drive
|
||||
/// the quic endpoint has crashed.
|
||||
/// The task to drive a quic endpoint has crashed.
|
||||
#[error("Endpoint driver crashed")]
|
||||
EndpointDriverCrashed,
|
||||
|
||||
@ -110,9 +110,9 @@ pub enum Error {
|
||||
/// Dialing a remote peer failed.
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
#[error(transparent)]
|
||||
pub struct ConnectError(#[from] quinn_proto::ConnectError);
|
||||
pub struct ConnectError(quinn::ConnectError);
|
||||
|
||||
/// Error on an established [`Connection`].
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
#[error(transparent)]
|
||||
pub struct ConnectionError(#[from] quinn_proto::ConnectionError);
|
||||
pub struct ConnectionError(quinn::ConnectionError);
|
||||
|
@ -18,11 +18,11 @@
|
||||
// FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
|
||||
// DEALINGS IN THE SOFTWARE.
|
||||
|
||||
use futures::{future::BoxFuture, Future};
|
||||
use futures::future::BoxFuture;
|
||||
use if_watch::IfEvent;
|
||||
use std::{
|
||||
io,
|
||||
net::SocketAddr,
|
||||
net::{SocketAddr, UdpSocket},
|
||||
task::{Context, Poll},
|
||||
time::Duration,
|
||||
};
|
||||
@ -32,40 +32,20 @@ pub mod async_std;
|
||||
#[cfg(feature = "tokio")]
|
||||
pub mod tokio;
|
||||
|
||||
/// Size of the buffer for reading data 0x10000.
|
||||
#[cfg(any(feature = "async-std", feature = "tokio"))]
|
||||
const RECEIVE_BUFFER_SIZE: usize = 65536;
|
||||
pub enum Runtime {
|
||||
#[cfg(feature = "tokio")]
|
||||
Tokio,
|
||||
#[cfg(feature = "async-std")]
|
||||
AsyncStd,
|
||||
Dummy,
|
||||
}
|
||||
|
||||
/// Provider for non-blocking receiving and sending on a [`std::net::UdpSocket`]
|
||||
/// and spawning tasks.
|
||||
/// Provider for a corresponding quinn runtime and spawning tasks.
|
||||
pub trait Provider: Unpin + Send + Sized + 'static {
|
||||
type IfWatcher: Unpin + Send;
|
||||
|
||||
/// Create a new providing that is wrapping the socket.
|
||||
///
|
||||
/// Note: The socket must be set to non-blocking.
|
||||
fn from_socket(socket: std::net::UdpSocket) -> io::Result<Self>;
|
||||
|
||||
/// Receive a single packet.
|
||||
///
|
||||
/// Returns the message and the address the message came from.
|
||||
fn poll_recv_from(&mut self, cx: &mut Context<'_>) -> Poll<io::Result<(Vec<u8>, SocketAddr)>>;
|
||||
|
||||
/// Set sending a packet on the socket.
|
||||
///
|
||||
/// Since only one packet can be sent at a time, this may only be called if a preceding
|
||||
/// call to [`Provider::poll_send_flush`] returned [`Poll::Ready`].
|
||||
fn start_send(&mut self, data: Vec<u8>, addr: SocketAddr);
|
||||
|
||||
/// Flush a packet send in [`Provider::start_send`].
|
||||
///
|
||||
/// If [`Poll::Ready`] is returned the socket is ready for sending a new packet.
|
||||
fn poll_send_flush(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>>;
|
||||
|
||||
/// Run the given future in the background until it ends.
|
||||
///
|
||||
/// This is used to spawn the task that is driving the endpoint.
|
||||
fn spawn(future: impl Future<Output = ()> + Send + 'static);
|
||||
/// Run the corresponding runtime
|
||||
fn runtime() -> Runtime;
|
||||
|
||||
/// Create a new [`if_watch`] watcher that reports [`IfEvent`]s for network interface changes.
|
||||
fn new_if_watcher() -> io::Result<Self::IfWatcher>;
|
||||
@ -78,4 +58,11 @@ pub trait Provider: Unpin + Send + Sized + 'static {
|
||||
|
||||
/// Sleep for specified amount of time.
|
||||
fn sleep(duration: Duration) -> BoxFuture<'static, ()>;
|
||||
|
||||
/// Sends data on the socket to the given address. On success, returns the number of bytes written.
|
||||
fn send_to<'a>(
|
||||
udp_socket: &'a UdpSocket,
|
||||
buf: &'a [u8],
|
||||
target: SocketAddr,
|
||||
) -> BoxFuture<'a, io::Result<usize>>;
|
||||
}
|
||||
|
@ -18,13 +18,10 @@
|
||||
// FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
|
||||
// DEALINGS IN THE SOFTWARE.
|
||||
|
||||
use async_std::{net::UdpSocket, task::spawn};
|
||||
use futures::{future::BoxFuture, ready, Future, FutureExt, Stream, StreamExt};
|
||||
use futures::{future::BoxFuture, FutureExt};
|
||||
use std::{
|
||||
io,
|
||||
net::SocketAddr,
|
||||
pin::Pin,
|
||||
sync::Arc,
|
||||
net::UdpSocket,
|
||||
task::{Context, Poll},
|
||||
time::Duration,
|
||||
};
|
||||
@ -34,65 +31,14 @@ use crate::GenTransport;
|
||||
/// Transport with [`async-std`] runtime.
|
||||
pub type Transport = GenTransport<Provider>;
|
||||
|
||||
/// Provider for reading / writing to a sockets and spawning
|
||||
/// tasks using [`async-std`].
|
||||
pub struct Provider {
|
||||
socket: Arc<UdpSocket>,
|
||||
// Future for sending a packet.
|
||||
// This is needed since [`async_Std::net::UdpSocket`] does not
|
||||
// provide a poll-style interface for sending a packet.
|
||||
send_packet: Option<BoxFuture<'static, Result<(), io::Error>>>,
|
||||
recv_stream: ReceiveStream,
|
||||
}
|
||||
/// Provider for quinn runtime and spawning tasks using [`async-std`].
|
||||
pub struct Provider;
|
||||
|
||||
impl super::Provider for Provider {
|
||||
type IfWatcher = if_watch::smol::IfWatcher;
|
||||
|
||||
fn from_socket(socket: std::net::UdpSocket) -> io::Result<Self> {
|
||||
let socket = Arc::new(socket.into());
|
||||
let recv_stream = ReceiveStream::new(Arc::clone(&socket));
|
||||
Ok(Provider {
|
||||
socket,
|
||||
send_packet: None,
|
||||
recv_stream,
|
||||
})
|
||||
}
|
||||
|
||||
fn poll_recv_from(&mut self, cx: &mut Context<'_>) -> Poll<io::Result<(Vec<u8>, SocketAddr)>> {
|
||||
match self.recv_stream.poll_next_unpin(cx) {
|
||||
Poll::Ready(ready) => {
|
||||
Poll::Ready(ready.expect("ReceiveStream::poll_next never returns None."))
|
||||
}
|
||||
Poll::Pending => Poll::Pending,
|
||||
}
|
||||
}
|
||||
|
||||
fn start_send(&mut self, data: Vec<u8>, addr: SocketAddr) {
|
||||
let socket = self.socket.clone();
|
||||
let send = async move {
|
||||
socket.send_to(&data, addr).await?;
|
||||
Ok(())
|
||||
}
|
||||
.boxed();
|
||||
self.send_packet = Some(send)
|
||||
}
|
||||
|
||||
fn poll_send_flush(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
|
||||
let pending = match self.send_packet.as_mut() {
|
||||
Some(pending) => pending,
|
||||
None => return Poll::Ready(Ok(())),
|
||||
};
|
||||
match pending.poll_unpin(cx) {
|
||||
Poll::Ready(result) => {
|
||||
self.send_packet = None;
|
||||
Poll::Ready(result)
|
||||
}
|
||||
Poll::Pending => Poll::Pending,
|
||||
}
|
||||
}
|
||||
|
||||
fn spawn(future: impl Future<Output = ()> + Send + 'static) {
|
||||
spawn(future);
|
||||
fn runtime() -> super::Runtime {
|
||||
super::Runtime::AsyncStd
|
||||
}
|
||||
|
||||
fn new_if_watcher() -> io::Result<Self::IfWatcher> {
|
||||
@ -109,48 +55,16 @@ impl super::Provider for Provider {
|
||||
fn sleep(duration: Duration) -> BoxFuture<'static, ()> {
|
||||
async_std::task::sleep(duration).boxed()
|
||||
}
|
||||
}
|
||||
|
||||
type ReceiveStreamItem = (
|
||||
Result<(usize, SocketAddr), io::Error>,
|
||||
Arc<UdpSocket>,
|
||||
Vec<u8>,
|
||||
);
|
||||
|
||||
/// Wrapper around the socket to implement `Stream` on it.
|
||||
struct ReceiveStream {
|
||||
/// Future for receiving a packet on the socket.
|
||||
// This is needed since [`async_Std::net::UdpSocket`] does not
|
||||
// provide a poll-style interface for receiving packets.
|
||||
fut: BoxFuture<'static, ReceiveStreamItem>,
|
||||
}
|
||||
|
||||
impl ReceiveStream {
|
||||
fn new(socket: Arc<UdpSocket>) -> Self {
|
||||
let fut = ReceiveStream::next(socket, vec![0; super::RECEIVE_BUFFER_SIZE]).boxed();
|
||||
Self { fut: fut.boxed() }
|
||||
}
|
||||
|
||||
async fn next(socket: Arc<UdpSocket>, mut socket_recv_buffer: Vec<u8>) -> ReceiveStreamItem {
|
||||
let recv = socket.recv_from(&mut socket_recv_buffer).await;
|
||||
(recv, socket, socket_recv_buffer)
|
||||
}
|
||||
}
|
||||
|
||||
impl Stream for ReceiveStream {
|
||||
type Item = Result<(Vec<u8>, SocketAddr), io::Error>;
|
||||
|
||||
fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
|
||||
let (result, socket, buffer) = ready!(self.fut.poll_unpin(cx));
|
||||
|
||||
let result = result.map(|(packet_len, packet_src)| {
|
||||
debug_assert!(packet_len <= buffer.len());
|
||||
// Copies the bytes from the `socket_recv_buffer` they were written into.
|
||||
(buffer[..packet_len].into(), packet_src)
|
||||
});
|
||||
// Set the future for receiving the next packet on the stream.
|
||||
self.fut = ReceiveStream::next(socket, buffer).boxed();
|
||||
|
||||
Poll::Ready(Some(result))
|
||||
fn send_to<'a>(
|
||||
udp_socket: &'a UdpSocket,
|
||||
buf: &'a [u8],
|
||||
target: std::net::SocketAddr,
|
||||
) -> BoxFuture<'a, io::Result<usize>> {
|
||||
Box::pin(async move {
|
||||
async_std::net::UdpSocket::from(udp_socket.try_clone()?)
|
||||
.send_to(buf, target)
|
||||
.await
|
||||
})
|
||||
}
|
||||
}
|
||||
|
@ -18,72 +18,27 @@
|
||||
// FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
|
||||
// DEALINGS IN THE SOFTWARE.
|
||||
|
||||
use futures::{future::BoxFuture, ready, Future, FutureExt};
|
||||
use futures::{future::BoxFuture, FutureExt};
|
||||
use std::{
|
||||
io,
|
||||
net::SocketAddr,
|
||||
net::{SocketAddr, UdpSocket},
|
||||
task::{Context, Poll},
|
||||
time::Duration,
|
||||
};
|
||||
use tokio::{io::ReadBuf, net::UdpSocket};
|
||||
|
||||
use crate::GenTransport;
|
||||
|
||||
/// Transport with [`tokio`] runtime.
|
||||
pub type Transport = GenTransport<Provider>;
|
||||
|
||||
/// Provider for reading / writing to a sockets and spawning
|
||||
/// tasks using [`tokio`].
|
||||
pub struct Provider {
|
||||
socket: UdpSocket,
|
||||
socket_recv_buffer: Vec<u8>,
|
||||
next_packet_out: Option<(Vec<u8>, SocketAddr)>,
|
||||
}
|
||||
/// Provider for quinn runtime and spawning tasks using [`tokio`].
|
||||
pub struct Provider;
|
||||
|
||||
impl super::Provider for Provider {
|
||||
type IfWatcher = if_watch::tokio::IfWatcher;
|
||||
|
||||
fn from_socket(socket: std::net::UdpSocket) -> std::io::Result<Self> {
|
||||
let socket = UdpSocket::from_std(socket)?;
|
||||
Ok(Provider {
|
||||
socket,
|
||||
socket_recv_buffer: vec![0; super::RECEIVE_BUFFER_SIZE],
|
||||
next_packet_out: None,
|
||||
})
|
||||
}
|
||||
|
||||
fn poll_send_flush(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
|
||||
let (data, addr) = match self.next_packet_out.as_ref() {
|
||||
Some(pending) => pending,
|
||||
None => return Poll::Ready(Ok(())),
|
||||
};
|
||||
match self.socket.poll_send_to(cx, data.as_slice(), *addr) {
|
||||
Poll::Ready(result) => {
|
||||
self.next_packet_out = None;
|
||||
Poll::Ready(result.map(|_| ()))
|
||||
}
|
||||
Poll::Pending => Poll::Pending,
|
||||
}
|
||||
}
|
||||
|
||||
fn poll_recv_from(&mut self, cx: &mut Context<'_>) -> Poll<io::Result<(Vec<u8>, SocketAddr)>> {
|
||||
let Self {
|
||||
socket,
|
||||
socket_recv_buffer,
|
||||
..
|
||||
} = self;
|
||||
let mut read_buf = ReadBuf::new(socket_recv_buffer.as_mut_slice());
|
||||
let packet_src = ready!(socket.poll_recv_from(cx, &mut read_buf)?);
|
||||
let bytes = read_buf.filled().to_vec();
|
||||
Poll::Ready(Ok((bytes, packet_src)))
|
||||
}
|
||||
|
||||
fn start_send(&mut self, data: Vec<u8>, addr: SocketAddr) {
|
||||
self.next_packet_out = Some((data, addr));
|
||||
}
|
||||
|
||||
fn spawn(future: impl Future<Output = ()> + Send + 'static) {
|
||||
tokio::spawn(future);
|
||||
fn runtime() -> super::Runtime {
|
||||
super::Runtime::Tokio
|
||||
}
|
||||
|
||||
fn new_if_watcher() -> io::Result<Self::IfWatcher> {
|
||||
@ -100,4 +55,16 @@ impl super::Provider for Provider {
|
||||
fn sleep(duration: Duration) -> BoxFuture<'static, ()> {
|
||||
tokio::time::sleep(duration).boxed()
|
||||
}
|
||||
|
||||
fn send_to<'a>(
|
||||
udp_socket: &'a UdpSocket,
|
||||
buf: &'a [u8],
|
||||
target: SocketAddr,
|
||||
) -> BoxFuture<'a, io::Result<usize>> {
|
||||
Box::pin(async move {
|
||||
tokio::net::UdpSocket::from_std(udp_socket.try_clone()?)?
|
||||
.send_to(buf, target)
|
||||
.await
|
||||
})
|
||||
}
|
||||
}
|
||||
|
@ -18,12 +18,12 @@
|
||||
// FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
|
||||
// DEALINGS IN THE SOFTWARE.
|
||||
|
||||
use crate::endpoint::{Config, QuinnConfig, ToEndpoint};
|
||||
use crate::config::{Config, QuinnConfig};
|
||||
use crate::hole_punching::hole_puncher;
|
||||
use crate::provider::Provider;
|
||||
use crate::{endpoint, Connecting, Connection, Error};
|
||||
use crate::{ConnectError, Connecting, Connection, Error};
|
||||
|
||||
use futures::channel::{mpsc, oneshot};
|
||||
use futures::channel::oneshot;
|
||||
use futures::future::{BoxFuture, Either};
|
||||
use futures::ready;
|
||||
use futures::stream::StreamExt;
|
||||
@ -38,10 +38,10 @@ use libp2p_core::{
|
||||
};
|
||||
use libp2p_identity::PeerId;
|
||||
use std::collections::hash_map::{DefaultHasher, Entry};
|
||||
use std::collections::{HashMap, VecDeque};
|
||||
use std::collections::HashMap;
|
||||
use std::fmt;
|
||||
use std::hash::{Hash, Hasher};
|
||||
use std::net::IpAddr;
|
||||
use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, UdpSocket};
|
||||
use std::time::Duration;
|
||||
use std::{
|
||||
net::SocketAddr,
|
||||
@ -62,7 +62,7 @@ use std::{
|
||||
/// See <https://github.com/multiformats/multiaddr/issues/145>.
|
||||
#[derive(Debug)]
|
||||
pub struct GenTransport<P: Provider> {
|
||||
/// Config for the inner [`quinn_proto`] structs.
|
||||
/// Config for the inner [`quinn`] structs.
|
||||
quinn_config: QuinnConfig,
|
||||
/// Timeout for the [`Connecting`] future.
|
||||
handshake_timeout: Duration,
|
||||
@ -71,7 +71,7 @@ pub struct GenTransport<P: Provider> {
|
||||
/// Streams of active [`Listener`]s.
|
||||
listeners: SelectAll<Listener<P>>,
|
||||
/// Dialer for each socket family if no matching listener exists.
|
||||
dialer: HashMap<SocketFamily, Dialer>,
|
||||
dialer: HashMap<SocketFamily, quinn::Endpoint>,
|
||||
/// Waker to poll the transport again when a new dialer or listener is added.
|
||||
waker: Option<Waker>,
|
||||
/// Holepunching attempts
|
||||
@ -95,21 +95,57 @@ impl<P: Provider> GenTransport<P> {
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a new [`quinn::Endpoint`] with the given configs.
|
||||
fn new_endpoint(
|
||||
endpoint_config: quinn::EndpointConfig,
|
||||
server_config: Option<quinn::ServerConfig>,
|
||||
socket: UdpSocket,
|
||||
) -> Result<quinn::Endpoint, Error> {
|
||||
use crate::provider::Runtime;
|
||||
match P::runtime() {
|
||||
#[cfg(feature = "tokio")]
|
||||
Runtime::Tokio => {
|
||||
let runtime = std::sync::Arc::new(quinn::TokioRuntime);
|
||||
let endpoint =
|
||||
quinn::Endpoint::new(endpoint_config, server_config, socket, runtime)?;
|
||||
Ok(endpoint)
|
||||
}
|
||||
#[cfg(feature = "async-std")]
|
||||
Runtime::AsyncStd => {
|
||||
let runtime = std::sync::Arc::new(quinn::AsyncStdRuntime);
|
||||
let endpoint =
|
||||
quinn::Endpoint::new(endpoint_config, server_config, socket, runtime)?;
|
||||
Ok(endpoint)
|
||||
}
|
||||
Runtime::Dummy => {
|
||||
let _ = endpoint_config;
|
||||
let _ = server_config;
|
||||
let _ = socket;
|
||||
let err = std::io::Error::new(std::io::ErrorKind::Other, "no async runtime found");
|
||||
Err(Error::Io(err))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Extract the addr, quic version and peer id from the given [`Multiaddr`].
|
||||
fn remote_multiaddr_to_socketaddr(
|
||||
&self,
|
||||
addr: Multiaddr,
|
||||
check_unspecified_addr: bool,
|
||||
) -> Result<
|
||||
(SocketAddr, ProtocolVersion, Option<PeerId>),
|
||||
TransportError<<Self as Transport>::Error>,
|
||||
> {
|
||||
let (socket_addr, version, peer_id) = multiaddr_to_socketaddr(&addr, self.support_draft_29)
|
||||
.ok_or_else(|| TransportError::MultiaddrNotSupported(addr.clone()))?;
|
||||
if socket_addr.port() == 0 || socket_addr.ip().is_unspecified() {
|
||||
if check_unspecified_addr && (socket_addr.port() == 0 || socket_addr.ip().is_unspecified())
|
||||
{
|
||||
return Err(TransportError::MultiaddrNotSupported(addr));
|
||||
}
|
||||
Ok((socket_addr, version, peer_id))
|
||||
}
|
||||
|
||||
/// Pick any listener to use for dialing.
|
||||
fn eligible_listener(&mut self, socket_addr: &SocketAddr) -> Option<&mut Listener<P>> {
|
||||
let mut listeners: Vec<_> = self
|
||||
.listeners
|
||||
@ -118,7 +154,7 @@ impl<P: Provider> GenTransport<P> {
|
||||
if l.is_closed {
|
||||
return false;
|
||||
}
|
||||
let listen_addr = l.endpoint_channel.socket_addr();
|
||||
let listen_addr = l.socket_addr();
|
||||
SocketFamily::is_same(&listen_addr.ip(), &socket_addr.ip())
|
||||
&& listen_addr.ip().is_loopback() == socket_addr.ip().is_loopback()
|
||||
})
|
||||
@ -149,13 +185,16 @@ impl<P: Provider> Transport for GenTransport<P> {
|
||||
listener_id: ListenerId,
|
||||
addr: Multiaddr,
|
||||
) -> Result<(), TransportError<Self::Error>> {
|
||||
let (socket_addr, version, _peer_id) =
|
||||
multiaddr_to_socketaddr(&addr, self.support_draft_29)
|
||||
.ok_or(TransportError::MultiaddrNotSupported(addr))?;
|
||||
let (socket_addr, version, _peer_id) = self.remote_multiaddr_to_socketaddr(addr, false)?;
|
||||
let endpoint_config = self.quinn_config.endpoint_config.clone();
|
||||
let server_config = self.quinn_config.server_config.clone();
|
||||
let socket = UdpSocket::bind(socket_addr).map_err(Self::Error::from)?;
|
||||
let socket_c = socket.try_clone().map_err(Self::Error::from)?;
|
||||
let endpoint = Self::new_endpoint(endpoint_config, Some(server_config), socket)?;
|
||||
let listener = Listener::new(
|
||||
listener_id,
|
||||
socket_addr,
|
||||
self.quinn_config.clone(),
|
||||
socket_c,
|
||||
endpoint,
|
||||
self.handshake_timeout,
|
||||
version,
|
||||
)?;
|
||||
@ -194,46 +233,68 @@ impl<P: Provider> Transport for GenTransport<P> {
|
||||
}
|
||||
|
||||
fn dial(&mut self, addr: Multiaddr) -> Result<Self::Dial, TransportError<Self::Error>> {
|
||||
let (socket_addr, version, _peer_id) = self.remote_multiaddr_to_socketaddr(addr)?;
|
||||
let (socket_addr, version, _peer_id) = self.remote_multiaddr_to_socketaddr(addr, true)?;
|
||||
|
||||
let handshake_timeout = self.handshake_timeout;
|
||||
|
||||
let dialer_state = match self.eligible_listener(&socket_addr) {
|
||||
let endpoint = match self.eligible_listener(&socket_addr) {
|
||||
None => {
|
||||
// No listener. Get or create an explicit dialer.
|
||||
let socket_family = socket_addr.ip().into();
|
||||
let dialer = match self.dialer.entry(socket_family) {
|
||||
Entry::Occupied(occupied) => occupied.into_mut(),
|
||||
Entry::Occupied(occupied) => occupied.get().clone(),
|
||||
Entry::Vacant(vacant) => {
|
||||
if let Some(waker) = self.waker.take() {
|
||||
waker.wake();
|
||||
}
|
||||
vacant.insert(Dialer::new::<P>(self.quinn_config.clone(), socket_family)?)
|
||||
let listen_socket_addr = match socket_family {
|
||||
SocketFamily::Ipv4 => SocketAddr::new(Ipv4Addr::UNSPECIFIED.into(), 0),
|
||||
SocketFamily::Ipv6 => SocketAddr::new(Ipv6Addr::UNSPECIFIED.into(), 0),
|
||||
};
|
||||
let socket =
|
||||
UdpSocket::bind(listen_socket_addr).map_err(Self::Error::from)?;
|
||||
let endpoint_config = self.quinn_config.endpoint_config.clone();
|
||||
let endpoint = Self::new_endpoint(endpoint_config, None, socket)?;
|
||||
|
||||
vacant.insert(endpoint.clone());
|
||||
endpoint
|
||||
}
|
||||
};
|
||||
&mut dialer.state
|
||||
dialer
|
||||
}
|
||||
Some(listener) => &mut listener.dialer_state,
|
||||
Some(listener) => listener.endpoint.clone(),
|
||||
};
|
||||
Ok(dialer_state.new_dial(socket_addr, handshake_timeout, version))
|
||||
let handshake_timeout = self.handshake_timeout;
|
||||
let mut client_config = self.quinn_config.client_config.clone();
|
||||
if version == ProtocolVersion::Draft29 {
|
||||
client_config.version(0xff00_001d);
|
||||
}
|
||||
Ok(Box::pin(async move {
|
||||
// This `"l"` seems necessary because an empty string is an invalid domain
|
||||
// name. While we don't use domain names, the underlying rustls library
|
||||
// is based upon the assumption that we do.
|
||||
let connecting = endpoint
|
||||
.connect_with(client_config, socket_addr, "l")
|
||||
.map_err(ConnectError)?;
|
||||
Connecting::new(connecting, handshake_timeout).await
|
||||
}))
|
||||
}
|
||||
|
||||
fn dial_as_listener(
|
||||
&mut self,
|
||||
addr: Multiaddr,
|
||||
) -> Result<Self::Dial, TransportError<Self::Error>> {
|
||||
let (socket_addr, _version, peer_id) = self.remote_multiaddr_to_socketaddr(addr.clone())?;
|
||||
let (socket_addr, _version, peer_id) =
|
||||
self.remote_multiaddr_to_socketaddr(addr.clone(), true)?;
|
||||
let peer_id = peer_id.ok_or(TransportError::MultiaddrNotSupported(addr))?;
|
||||
|
||||
let endpoint_channel = self
|
||||
let socket = self
|
||||
.eligible_listener(&socket_addr)
|
||||
.ok_or(TransportError::Other(
|
||||
Error::NoActiveListenerForDialAsListener,
|
||||
))?
|
||||
.endpoint_channel
|
||||
.clone();
|
||||
.try_clone_socket()
|
||||
.map_err(Self::Error::from)?;
|
||||
|
||||
let hole_puncher = hole_puncher::<P>(endpoint_channel, socket_addr, self.handshake_timeout);
|
||||
let hole_puncher = hole_puncher::<P>(socket, socket_addr, self.handshake_timeout);
|
||||
|
||||
let (sender, receiver) = oneshot::channel();
|
||||
|
||||
@ -274,19 +335,6 @@ impl<P: Provider> Transport for GenTransport<P> {
|
||||
mut self: Pin<&mut Self>,
|
||||
cx: &mut Context<'_>,
|
||||
) -> Poll<TransportEvent<Self::ListenerUpgrade, Self::Error>> {
|
||||
let mut errored = Vec::new();
|
||||
for (key, dialer) in &mut self.dialer {
|
||||
if let Poll::Ready(_error) = dialer.poll(cx) {
|
||||
errored.push(*key);
|
||||
}
|
||||
}
|
||||
|
||||
for key in errored {
|
||||
// Endpoint driver of dialer crashed.
|
||||
// Drop dialer and all pending dials so that the connection receiver is notified.
|
||||
self.dialer.remove(&key);
|
||||
}
|
||||
|
||||
while let Poll::Ready(Some(ev)) = self.listeners.poll_next_unpin(cx) {
|
||||
match ev {
|
||||
TransportEvent::Incoming {
|
||||
@ -331,112 +379,22 @@ impl From<Error> for TransportError<Error> {
|
||||
}
|
||||
}
|
||||
|
||||
/// Dialer for addresses if no matching listener exists.
|
||||
#[derive(Debug)]
|
||||
struct Dialer {
|
||||
/// Channel to the [`crate::endpoint::Driver`] that
|
||||
/// is driving the endpoint.
|
||||
endpoint_channel: endpoint::Channel,
|
||||
/// Queued dials for the endpoint.
|
||||
state: DialerState,
|
||||
}
|
||||
|
||||
impl Dialer {
|
||||
fn new<P: Provider>(
|
||||
config: QuinnConfig,
|
||||
socket_family: SocketFamily,
|
||||
) -> Result<Self, TransportError<Error>> {
|
||||
let endpoint_channel = endpoint::Channel::new_dialer::<P>(config, socket_family)
|
||||
.map_err(TransportError::Other)?;
|
||||
Ok(Dialer {
|
||||
endpoint_channel,
|
||||
state: DialerState::default(),
|
||||
})
|
||||
}
|
||||
|
||||
fn poll(&mut self, cx: &mut Context<'_>) -> Poll<Error> {
|
||||
self.state.poll(&mut self.endpoint_channel, cx)
|
||||
}
|
||||
}
|
||||
|
||||
impl Drop for Dialer {
|
||||
fn drop(&mut self) {
|
||||
self.endpoint_channel.send_on_drop(ToEndpoint::Decoupled);
|
||||
}
|
||||
}
|
||||
|
||||
/// Pending dials to be sent to the endpoint was the [`endpoint::Channel`]
|
||||
/// has capacity
|
||||
#[derive(Default, Debug)]
|
||||
struct DialerState {
|
||||
pending_dials: VecDeque<ToEndpoint>,
|
||||
waker: Option<Waker>,
|
||||
}
|
||||
|
||||
impl DialerState {
|
||||
fn new_dial(
|
||||
&mut self,
|
||||
address: SocketAddr,
|
||||
timeout: Duration,
|
||||
version: ProtocolVersion,
|
||||
) -> BoxFuture<'static, Result<(PeerId, Connection), Error>> {
|
||||
let (rx, tx) = oneshot::channel();
|
||||
|
||||
let message = ToEndpoint::Dial {
|
||||
addr: address,
|
||||
result: rx,
|
||||
version,
|
||||
};
|
||||
|
||||
self.pending_dials.push_back(message);
|
||||
|
||||
if let Some(waker) = self.waker.take() {
|
||||
waker.wake();
|
||||
}
|
||||
|
||||
async move {
|
||||
// Our oneshot getting dropped means the message didn't make it to the endpoint driver.
|
||||
let connection = tx.await.map_err(|_| Error::EndpointDriverCrashed)??;
|
||||
let (peer, connection) = Connecting::new(connection, timeout).await?;
|
||||
|
||||
Ok((peer, connection))
|
||||
}
|
||||
.boxed()
|
||||
}
|
||||
|
||||
/// Send all pending dials into the given [`endpoint::Channel`].
|
||||
///
|
||||
/// This only ever returns [`Poll::Pending`], or an error in case the channel is closed.
|
||||
fn poll(&mut self, channel: &mut endpoint::Channel, cx: &mut Context<'_>) -> Poll<Error> {
|
||||
while let Some(to_endpoint) = self.pending_dials.pop_front() {
|
||||
match channel.try_send(to_endpoint, cx) {
|
||||
Ok(Ok(())) => {}
|
||||
Ok(Err(to_endpoint)) => {
|
||||
self.pending_dials.push_front(to_endpoint);
|
||||
break;
|
||||
}
|
||||
Err(endpoint::Disconnected {}) => return Poll::Ready(Error::EndpointDriverCrashed),
|
||||
}
|
||||
}
|
||||
self.waker = Some(cx.waker().clone());
|
||||
Poll::Pending
|
||||
}
|
||||
}
|
||||
|
||||
/// Listener for incoming connections.
|
||||
struct Listener<P: Provider> {
|
||||
/// Id of the listener.
|
||||
listener_id: ListenerId,
|
||||
|
||||
/// Version of the supported quic protocol.
|
||||
version: ProtocolVersion,
|
||||
|
||||
/// Channel to the endpoint to initiate dials.
|
||||
endpoint_channel: endpoint::Channel,
|
||||
/// Queued dials.
|
||||
dialer_state: DialerState,
|
||||
/// Endpoint
|
||||
endpoint: quinn::Endpoint,
|
||||
|
||||
/// Channel where new connections are being sent.
|
||||
new_connections_rx: mpsc::Receiver<Connection>,
|
||||
/// An underlying copy of the socket to be able to hole punch with
|
||||
socket: UdpSocket,
|
||||
|
||||
/// A future to poll new incoming connections.
|
||||
accept: BoxFuture<'static, Option<quinn::Connecting>>,
|
||||
/// Timeout for connection establishment on inbound connections.
|
||||
handshake_timeout: Duration,
|
||||
|
||||
@ -458,38 +416,39 @@ struct Listener<P: Provider> {
|
||||
impl<P: Provider> Listener<P> {
|
||||
fn new(
|
||||
listener_id: ListenerId,
|
||||
socket_addr: SocketAddr,
|
||||
config: QuinnConfig,
|
||||
socket: UdpSocket,
|
||||
endpoint: quinn::Endpoint,
|
||||
handshake_timeout: Duration,
|
||||
version: ProtocolVersion,
|
||||
) -> Result<Self, Error> {
|
||||
let (endpoint_channel, new_connections_rx) =
|
||||
endpoint::Channel::new_bidirectional::<P>(config, socket_addr)?;
|
||||
|
||||
let if_watcher;
|
||||
let pending_event;
|
||||
if socket_addr.ip().is_unspecified() {
|
||||
let local_addr = socket.local_addr()?;
|
||||
if local_addr.ip().is_unspecified() {
|
||||
if_watcher = Some(P::new_if_watcher()?);
|
||||
pending_event = None;
|
||||
} else {
|
||||
if_watcher = None;
|
||||
let ma = socketaddr_to_multiaddr(endpoint_channel.socket_addr(), version);
|
||||
let ma = socketaddr_to_multiaddr(&local_addr, version);
|
||||
pending_event = Some(TransportEvent::NewAddress {
|
||||
listener_id,
|
||||
listen_addr: ma,
|
||||
})
|
||||
}
|
||||
|
||||
let endpoint_c = endpoint.clone();
|
||||
let accept = async move { endpoint_c.accept().await }.boxed();
|
||||
|
||||
Ok(Listener {
|
||||
endpoint_channel,
|
||||
endpoint,
|
||||
socket,
|
||||
accept,
|
||||
listener_id,
|
||||
version,
|
||||
new_connections_rx,
|
||||
handshake_timeout,
|
||||
if_watcher,
|
||||
is_closed: false,
|
||||
pending_event,
|
||||
dialer_state: DialerState::default(),
|
||||
close_listener_waker: None,
|
||||
})
|
||||
}
|
||||
@ -500,6 +459,7 @@ impl<P: Provider> Listener<P> {
|
||||
if self.is_closed {
|
||||
return;
|
||||
}
|
||||
self.endpoint.close(From::from(0u32), &[]);
|
||||
self.pending_event = Some(TransportEvent::ListenerClosed {
|
||||
listener_id: self.listener_id,
|
||||
reason,
|
||||
@ -512,8 +472,20 @@ impl<P: Provider> Listener<P> {
|
||||
}
|
||||
}
|
||||
|
||||
/// Clone underlying socket (for hole punching).
|
||||
fn try_clone_socket(&self) -> std::io::Result<UdpSocket> {
|
||||
self.socket.try_clone()
|
||||
}
|
||||
|
||||
fn socket_addr(&self) -> SocketAddr {
|
||||
self.socket
|
||||
.local_addr()
|
||||
.expect("Cannot fail because the socket is bound")
|
||||
}
|
||||
|
||||
/// Poll for a next If Event.
|
||||
fn poll_if_addr(&mut self, cx: &mut Context<'_>) -> Poll<<Self as Stream>::Item> {
|
||||
let endpoint_addr = self.socket_addr();
|
||||
let if_watcher = match self.if_watcher.as_mut() {
|
||||
Some(iw) => iw,
|
||||
None => return Poll::Pending,
|
||||
@ -521,11 +493,9 @@ impl<P: Provider> Listener<P> {
|
||||
loop {
|
||||
match ready!(P::poll_if_event(if_watcher, cx)) {
|
||||
Ok(IfEvent::Up(inet)) => {
|
||||
if let Some(listen_addr) = ip_to_listenaddr(
|
||||
self.endpoint_channel.socket_addr(),
|
||||
inet.addr(),
|
||||
self.version,
|
||||
) {
|
||||
if let Some(listen_addr) =
|
||||
ip_to_listenaddr(&endpoint_addr, inet.addr(), self.version)
|
||||
{
|
||||
log::debug!("New listen address: {}", listen_addr);
|
||||
return Poll::Ready(TransportEvent::NewAddress {
|
||||
listener_id: self.listener_id,
|
||||
@ -534,11 +504,9 @@ impl<P: Provider> Listener<P> {
|
||||
}
|
||||
}
|
||||
Ok(IfEvent::Down(inet)) => {
|
||||
if let Some(listen_addr) = ip_to_listenaddr(
|
||||
self.endpoint_channel.socket_addr(),
|
||||
inet.addr(),
|
||||
self.version,
|
||||
) {
|
||||
if let Some(listen_addr) =
|
||||
ip_to_listenaddr(&endpoint_addr, inet.addr(), self.version)
|
||||
{
|
||||
log::debug!("Expired listen address: {}", listen_addr);
|
||||
return Poll::Ready(TransportEvent::AddressExpired {
|
||||
listener_id: self.listener_id,
|
||||
@ -555,21 +523,10 @@ impl<P: Provider> Listener<P> {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Poll [`DialerState`] to initiate requested dials.
|
||||
fn poll_dialer(&mut self, cx: &mut Context<'_>) -> Poll<Error> {
|
||||
let Self {
|
||||
dialer_state,
|
||||
endpoint_channel,
|
||||
..
|
||||
} = self;
|
||||
|
||||
dialer_state.poll(endpoint_channel, cx)
|
||||
}
|
||||
}
|
||||
|
||||
impl<P: Provider> Stream for Listener<P> {
|
||||
type Item = TransportEvent<Connecting, Error>;
|
||||
type Item = TransportEvent<<GenTransport<P> as Transport>::ListenerUpgrade, Error>;
|
||||
fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
|
||||
loop {
|
||||
if let Some(event) = self.pending_event.take() {
|
||||
@ -581,17 +538,18 @@ impl<P: Provider> Stream for Listener<P> {
|
||||
if let Poll::Ready(event) = self.poll_if_addr(cx) {
|
||||
return Poll::Ready(Some(event));
|
||||
}
|
||||
if let Poll::Ready(error) = self.poll_dialer(cx) {
|
||||
self.close(Err(error));
|
||||
continue;
|
||||
}
|
||||
match self.new_connections_rx.poll_next_unpin(cx) {
|
||||
Poll::Ready(Some(connection)) => {
|
||||
let local_addr = socketaddr_to_multiaddr(connection.local_addr(), self.version);
|
||||
let send_back_addr =
|
||||
socketaddr_to_multiaddr(&connection.remote_addr(), self.version);
|
||||
|
||||
match self.accept.poll_unpin(cx) {
|
||||
Poll::Ready(Some(connecting)) => {
|
||||
let endpoint = self.endpoint.clone();
|
||||
self.accept = async move { endpoint.accept().await }.boxed();
|
||||
|
||||
let local_addr = socketaddr_to_multiaddr(&self.socket_addr(), self.version);
|
||||
let remote_addr = connecting.remote_address();
|
||||
let send_back_addr = socketaddr_to_multiaddr(&remote_addr, self.version);
|
||||
|
||||
let event = TransportEvent::Incoming {
|
||||
upgrade: Connecting::new(connection, self.handshake_timeout),
|
||||
upgrade: Connecting::new(connecting, self.handshake_timeout),
|
||||
local_addr,
|
||||
send_back_addr,
|
||||
listener_id: self.listener_id,
|
||||
@ -616,9 +574,6 @@ impl<P: Provider> fmt::Debug for Listener<P> {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
f.debug_struct("Listener")
|
||||
.field("listener_id", &self.listener_id)
|
||||
.field("endpoint_channel", &self.endpoint_channel)
|
||||
.field("dialer_state", &self.dialer_state)
|
||||
.field("new_connections_rx", &self.new_connections_rx)
|
||||
.field("handshake_timeout", &self.handshake_timeout)
|
||||
.field("is_closed", &self.is_closed)
|
||||
.field("pending_event", &self.pending_event)
|
||||
@ -626,12 +581,6 @@ impl<P: Provider> fmt::Debug for Listener<P> {
|
||||
}
|
||||
}
|
||||
|
||||
impl<P: Provider> Drop for Listener<P> {
|
||||
fn drop(&mut self) {
|
||||
self.endpoint_channel.send_on_drop(ToEndpoint::Decoupled);
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
pub(crate) enum ProtocolVersion {
|
||||
V1, // i.e. RFC9000
|
||||
@ -766,7 +715,6 @@ fn socketaddr_to_multiaddr(socket_addr: &SocketAddr, version: ProtocolVersion) -
|
||||
#[cfg(any(feature = "async-std", feature = "tokio"))]
|
||||
mod test {
|
||||
use futures::future::poll_fn;
|
||||
use futures_timer::Delay;
|
||||
use std::net::{IpAddr, Ipv4Addr, Ipv6Addr};
|
||||
|
||||
use super::*;
|
||||
@ -882,15 +830,6 @@ mod test {
|
||||
.listen_on(id, "/ip4/0.0.0.0/udp/0/quic-v1".parse().unwrap())
|
||||
.unwrap();
|
||||
|
||||
// Copy channel to use it later.
|
||||
let mut channel = transport
|
||||
.listeners
|
||||
.iter()
|
||||
.next()
|
||||
.unwrap()
|
||||
.endpoint_channel
|
||||
.clone();
|
||||
|
||||
match poll_fn(|cx| Pin::new(&mut transport).as_mut().poll(cx)).await {
|
||||
TransportEvent::NewAddress {
|
||||
listener_id,
|
||||
@ -923,14 +862,6 @@ mod test {
|
||||
.now_or_never()
|
||||
.is_none());
|
||||
assert!(transport.listeners.is_empty());
|
||||
|
||||
// Check that the [`Driver`] has shut down.
|
||||
Delay::new(Duration::from_millis(10)).await;
|
||||
poll_fn(|cx| {
|
||||
assert!(channel.try_send(ToEndpoint::Decoupled, cx).is_err());
|
||||
Poll::Ready(())
|
||||
})
|
||||
.await;
|
||||
}
|
||||
}
|
||||
|
||||
@ -945,32 +876,9 @@ mod test {
|
||||
.dial("/ip4/123.45.67.8/udp/1234/quic-v1".parse().unwrap())
|
||||
.unwrap();
|
||||
|
||||
// Expect a dialer and its background task to exist.
|
||||
let mut channel = transport
|
||||
.dialer
|
||||
.get(&SocketFamily::Ipv4)
|
||||
.unwrap()
|
||||
.endpoint_channel
|
||||
.clone();
|
||||
assert!(transport.dialer.contains_key(&SocketFamily::Ipv4));
|
||||
assert!(!transport.dialer.contains_key(&SocketFamily::Ipv6));
|
||||
|
||||
// Send dummy dial to check that the endpoint driver is running.
|
||||
poll_fn(|cx| {
|
||||
let (tx, _) = oneshot::channel();
|
||||
let _ = channel
|
||||
.try_send(
|
||||
ToEndpoint::Dial {
|
||||
addr: SocketAddr::new(IpAddr::V4(Ipv4Addr::UNSPECIFIED), 0),
|
||||
result: tx,
|
||||
version: ProtocolVersion::V1,
|
||||
},
|
||||
cx,
|
||||
)
|
||||
.unwrap();
|
||||
Poll::Ready(())
|
||||
})
|
||||
.await;
|
||||
|
||||
// Start listening so that the dialer and driver are dropped.
|
||||
transport
|
||||
.listen_on(
|
||||
@ -979,13 +887,5 @@ mod test {
|
||||
)
|
||||
.unwrap();
|
||||
assert!(!transport.dialer.contains_key(&SocketFamily::Ipv4));
|
||||
|
||||
// Check that the [`Driver`] has shut down.
|
||||
Delay::new(Duration::from_millis(10)).await;
|
||||
poll_fn(|cx| {
|
||||
assert!(channel.try_send(ToEndpoint::Decoupled, cx).is_err());
|
||||
Poll::Ready(())
|
||||
})
|
||||
.await;
|
||||
}
|
||||
}
|
||||
|
@ -428,7 +428,7 @@ async fn smoke<P: Provider>() {
|
||||
assert_eq!(b_connected, a_peer_id);
|
||||
}
|
||||
|
||||
async fn build_streams<P: Provider>() -> (SubstreamBox, SubstreamBox) {
|
||||
async fn build_streams<P: Provider + Spawn>() -> (SubstreamBox, SubstreamBox) {
|
||||
let (_, mut a_transport) = create_default_transport::<P>();
|
||||
let (_, mut b_transport) = create_default_transport::<P>();
|
||||
|
||||
@ -522,7 +522,7 @@ async fn start_listening(transport: &mut Boxed<(PeerId, StreamMuxerBox)>, addr:
|
||||
}
|
||||
}
|
||||
|
||||
fn prop<P: Provider + BlockOn>(
|
||||
fn prop<P: Provider + BlockOn + Spawn>(
|
||||
number_listeners: NonZeroU8,
|
||||
number_streams: NonZeroU8,
|
||||
) -> quickcheck::TestResult {
|
||||
@ -599,7 +599,7 @@ fn prop<P: Provider + BlockOn>(
|
||||
quickcheck::TestResult::passed()
|
||||
}
|
||||
|
||||
async fn answer_inbound_streams<P: Provider, const BUFFER_SIZE: usize>(
|
||||
async fn answer_inbound_streams<P: Provider + Spawn, const BUFFER_SIZE: usize>(
|
||||
mut connection: StreamMuxerBox,
|
||||
) {
|
||||
loop {
|
||||
@ -634,7 +634,7 @@ async fn answer_inbound_streams<P: Provider, const BUFFER_SIZE: usize>(
|
||||
}
|
||||
}
|
||||
|
||||
async fn open_outbound_streams<P: Provider, const BUFFER_SIZE: usize>(
|
||||
async fn open_outbound_streams<P: Provider + Spawn, const BUFFER_SIZE: usize>(
|
||||
mut connection: StreamMuxerBox,
|
||||
number_streams: usize,
|
||||
completed_streams_tx: mpsc::Sender<()>,
|
||||
@ -740,3 +740,22 @@ impl BlockOn for libp2p_quic::tokio::Provider {
|
||||
.unwrap()
|
||||
}
|
||||
}
|
||||
|
||||
trait Spawn {
|
||||
/// Run the given future in the background until it ends.
|
||||
fn spawn(future: impl Future<Output = ()> + Send + 'static);
|
||||
}
|
||||
|
||||
#[cfg(feature = "async-std")]
|
||||
impl Spawn for libp2p_quic::async_std::Provider {
|
||||
fn spawn(future: impl Future<Output = ()> + Send + 'static) {
|
||||
async_std::task::spawn(future);
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(feature = "tokio")]
|
||||
impl Spawn for libp2p_quic::tokio::Provider {
|
||||
fn spawn(future: impl Future<Output = ()> + Send + 'static) {
|
||||
tokio::spawn(future);
|
||||
}
|
||||
}
|
||||
|
Reference in New Issue
Block a user