mirror of
https://github.com/fluencelabs/rust-libp2p
synced 2025-06-29 17:51:35 +00:00
Upgrade websocket transport to soketto 0.3.0. (#1266)
Upgrade websocket transport to soketto 0.3.0.
This commit is contained in:
@ -22,7 +22,7 @@ log = "0.4"
|
|||||||
multiaddr = { package = "parity-multiaddr", version = "0.5.0", path = "../misc/multiaddr" }
|
multiaddr = { package = "parity-multiaddr", version = "0.5.0", path = "../misc/multiaddr" }
|
||||||
multihash = { package = "parity-multihash", version = "0.1.0", path = "../misc/multihash" }
|
multihash = { package = "parity-multihash", version = "0.1.0", path = "../misc/multihash" }
|
||||||
multistream-select = { version = "0.5.0", path = "../misc/multistream-select" }
|
multistream-select = { version = "0.5.0", path = "../misc/multistream-select" }
|
||||||
futures-preview = { version = "0.3.0-alpha.18", features = ["compat", "io-compat"] }
|
futures-preview = { version = "= 0.3.0-alpha.18", features = ["compat", "io-compat"] }
|
||||||
parking_lot = "0.8"
|
parking_lot = "0.8"
|
||||||
protobuf = "2.3"
|
protobuf = "2.3"
|
||||||
quick-error = "1.2"
|
quick-error = "1.2"
|
||||||
|
@ -10,4 +10,4 @@ keywords = ["networking"]
|
|||||||
categories = ["network-programming", "asynchronous"]
|
categories = ["network-programming", "asynchronous"]
|
||||||
|
|
||||||
[dependencies]
|
[dependencies]
|
||||||
futures-preview = "0.3.0-alpha.18"
|
futures-preview = "= 0.3.0-alpha.18"
|
||||||
|
@ -28,7 +28,7 @@
|
|||||||
//! > not at all specific to libp2p.
|
//! > not at all specific to libp2p.
|
||||||
|
|
||||||
use futures::{prelude::*, io::Initializer};
|
use futures::{prelude::*, io::Initializer};
|
||||||
use std::{cmp, io, marker::PhantomData, pin::Pin, task::Context, task::Poll};
|
use std::{cmp, io, pin::Pin, task::Context, task::Poll};
|
||||||
|
|
||||||
/// Wraps around a `Stream + Sink` whose items are buffers. Implements `AsyncRead` and `AsyncWrite`.
|
/// Wraps around a `Stream + Sink` whose items are buffers. Implements `AsyncRead` and `AsyncWrite`.
|
||||||
///
|
///
|
||||||
|
@ -10,15 +10,14 @@ keywords = ["peer-to-peer", "libp2p", "networking"]
|
|||||||
categories = ["network-programming", "asynchronous"]
|
categories = ["network-programming", "asynchronous"]
|
||||||
|
|
||||||
[dependencies]
|
[dependencies]
|
||||||
bytes = "0.4.6"
|
bytes = "0.4.12"
|
||||||
futures-preview = { version = "0.3.0-alpha.18", features = ["compat"] }
|
either = "1.5.3"
|
||||||
futures_codec = "0.2.0"
|
futures-preview = "= 0.3.0-alpha.18"
|
||||||
|
futures-rustls = "0.12.0-alpha"
|
||||||
libp2p-core = { version = "0.12.0", path = "../../core" }
|
libp2p-core = { version = "0.12.0", path = "../../core" }
|
||||||
log = "0.4.1"
|
log = "0.4.8"
|
||||||
rw-stream-sink = { version = "0.1.1", path = "../../misc/rw-stream-sink" }
|
rw-stream-sink = { version = "0.1.1", path = "../../misc/rw-stream-sink" }
|
||||||
tokio-io = "0.1.12"
|
soketto = { git = "https://github.com/paritytech/soketto.git", branch = "develop", features = ["deflate"] }
|
||||||
tokio-rustls = "0.10.0-alpha.3"
|
|
||||||
soketto = { version = "0.2.3", features = ["deflate"] }
|
|
||||||
url = "1.7.2"
|
url = "1.7.2"
|
||||||
webpki-roots = "0.16.0"
|
webpki-roots = "0.16.0"
|
||||||
|
|
||||||
|
@ -20,8 +20,9 @@
|
|||||||
|
|
||||||
use bytes::BytesMut;
|
use bytes::BytesMut;
|
||||||
use crate::{error::Error, tls};
|
use crate::{error::Error, tls};
|
||||||
use futures::{future::{self, Either, Loop}, prelude::*, ready};
|
use either::Either;
|
||||||
use futures_codec::{Framed, FramedParts};
|
use futures::{prelude::*, ready};
|
||||||
|
use futures_rustls::{client, server, webpki};
|
||||||
use libp2p_core::{
|
use libp2p_core::{
|
||||||
Transport,
|
Transport,
|
||||||
either::EitherOutput,
|
either::EitherOutput,
|
||||||
@ -29,19 +30,12 @@ use libp2p_core::{
|
|||||||
transport::{ListenerEvent, TransportError}
|
transport::{ListenerEvent, TransportError}
|
||||||
};
|
};
|
||||||
use log::{debug, trace};
|
use log::{debug, trace};
|
||||||
use tokio_rustls::{client, server};
|
use soketto::{connection::{self, Connection}, extension::deflate::Deflate, handshake};
|
||||||
use soketto::{
|
use std::{io, pin::Pin, task::Context, task::Poll};
|
||||||
base,
|
|
||||||
connection::{Connection, Mode},
|
|
||||||
extension::deflate::Deflate,
|
|
||||||
handshake::{self, Redirect, Response}
|
|
||||||
};
|
|
||||||
use std::{convert::TryFrom, io, pin::Pin, task::Context, task::Poll};
|
|
||||||
use tokio_rustls::webpki;
|
|
||||||
use url::Url;
|
use url::Url;
|
||||||
|
|
||||||
/// Max. number of payload bytes of a single frame.
|
/// Max. number of payload bytes of a single frame.
|
||||||
const MAX_DATA_SIZE: u64 = 256 * 1024 * 1024;
|
const MAX_DATA_SIZE: usize = 256 * 1024 * 1024;
|
||||||
|
|
||||||
/// A Websocket transport whose output type is a [`Stream`] and [`Sink`] of
|
/// A Websocket transport whose output type is a [`Stream`] and [`Sink`] of
|
||||||
/// frame payloads which does not implement [`AsyncRead`] or
|
/// frame payloads which does not implement [`AsyncRead`] or
|
||||||
@ -49,7 +43,7 @@ const MAX_DATA_SIZE: u64 = 256 * 1024 * 1024;
|
|||||||
#[derive(Debug, Clone)]
|
#[derive(Debug, Clone)]
|
||||||
pub struct WsConfig<T> {
|
pub struct WsConfig<T> {
|
||||||
transport: T,
|
transport: T,
|
||||||
max_data_size: u64,
|
max_data_size: usize,
|
||||||
tls_config: tls::Config,
|
tls_config: tls::Config,
|
||||||
max_redirects: u8,
|
max_redirects: u8,
|
||||||
use_deflate: bool
|
use_deflate: bool
|
||||||
@ -79,12 +73,12 @@ impl<T> WsConfig<T> {
|
|||||||
}
|
}
|
||||||
|
|
||||||
/// Get the max. frame data size we support.
|
/// Get the max. frame data size we support.
|
||||||
pub fn max_data_size(&self) -> u64 {
|
pub fn max_data_size(&self) -> usize {
|
||||||
self.max_data_size
|
self.max_data_size
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Set the max. frame data size we support.
|
/// Set the max. frame data size we support.
|
||||||
pub fn set_max_data_size(&mut self, size: u64) -> &mut Self {
|
pub fn set_max_data_size(&mut self, size: usize) -> &mut Self {
|
||||||
self.max_data_size = size;
|
self.max_data_size = size;
|
||||||
self
|
self
|
||||||
}
|
}
|
||||||
@ -102,14 +96,16 @@ impl<T> WsConfig<T> {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type TlsOrPlain<T> = EitherOutput<EitherOutput<client::TlsStream<T>, server::TlsStream<T>>, T>;
|
||||||
|
|
||||||
impl<T> Transport for WsConfig<T>
|
impl<T> Transport for WsConfig<T>
|
||||||
where
|
where
|
||||||
T: Transport + Send + Clone + 'static,
|
T: Transport + Send + Clone + 'static,
|
||||||
T::Error: Send + 'static,
|
T::Error: Send + 'static,
|
||||||
T::Dial: Send + 'static,
|
T::Dial: Send + 'static,
|
||||||
T::Listener: Send + 'static,
|
T::Listener: Send + Unpin + 'static,
|
||||||
T::ListenerUpgrade: Send + 'static,
|
T::ListenerUpgrade: Send + 'static,
|
||||||
T::Output: AsyncRead + AsyncWrite + Send + 'static
|
T::Output: AsyncRead + AsyncWrite + Unpin + Send + 'static
|
||||||
{
|
{
|
||||||
type Output = BytesConnection<T::Output>;
|
type Output = BytesConnection<T::Output>;
|
||||||
type Error = Error<T::Error>;
|
type Error = Error<T::Error>;
|
||||||
@ -138,10 +134,10 @@ where
|
|||||||
let tls_config = self.tls_config;
|
let tls_config = self.tls_config;
|
||||||
let max_size = self.max_data_size;
|
let max_size = self.max_data_size;
|
||||||
let use_deflate = self.use_deflate;
|
let use_deflate = self.use_deflate;
|
||||||
let listen = self.transport.listen_on(inner_addr)
|
let transport = self.transport.listen_on(inner_addr).map_err(|e| e.map(Error::Transport))?;
|
||||||
.map_err(|e| e.map(Error::Transport))?
|
let listen = transport
|
||||||
.map_err(Error::Transport)
|
.map_err(Error::Transport)
|
||||||
.map(move |event| match event {
|
.map_ok(move |event| match event {
|
||||||
ListenerEvent::NewAddress(mut a) => {
|
ListenerEvent::NewAddress(mut a) => {
|
||||||
a = a.with(proto.clone());
|
a = a.with(proto.clone());
|
||||||
debug!("Listening on {}", a);
|
debug!("Listening on {}", a);
|
||||||
@ -157,60 +153,76 @@ where
|
|||||||
let remote1 = remote_addr.clone(); // used for logging
|
let remote1 = remote_addr.clone(); // used for logging
|
||||||
let remote2 = remote_addr.clone(); // used for logging
|
let remote2 = remote_addr.clone(); // used for logging
|
||||||
let tls_config = tls_config.clone();
|
let tls_config = tls_config.clone();
|
||||||
let upgraded = upgrade.map_err(Error::Transport)
|
|
||||||
.and_then(move |stream| {
|
let upgrade = async move {
|
||||||
trace!("incoming connection from {}", remote1);
|
let stream = upgrade.map_err(Error::Transport).await?;
|
||||||
|
trace!("incoming connection from {}", remote1);
|
||||||
|
|
||||||
|
let stream =
|
||||||
if use_tls { // begin TLS session
|
if use_tls { // begin TLS session
|
||||||
let server = tls_config.server.expect("for use_tls we checked server");
|
let server = tls_config
|
||||||
|
.server
|
||||||
|
.expect("for use_tls we checked server is not none");
|
||||||
|
|
||||||
trace!("awaiting TLS handshake with {}", remote1);
|
trace!("awaiting TLS handshake with {}", remote1);
|
||||||
let future = server.accept(stream)
|
|
||||||
|
let stream = server.accept(stream)
|
||||||
.map_err(move |e| {
|
.map_err(move |e| {
|
||||||
debug!("TLS handshake with {} failed: {}", remote1, e);
|
debug!("TLS handshake with {} failed: {}", remote1, e);
|
||||||
Error::Tls(tls::Error::from(e))
|
Error::Tls(tls::Error::from(e))
|
||||||
})
|
})
|
||||||
.map(|s| EitherOutput::First(EitherOutput::Second(s)));
|
.await?;
|
||||||
Either::Left(future)
|
|
||||||
|
let stream: TlsOrPlain<_> =
|
||||||
|
EitherOutput::First(EitherOutput::Second(stream));
|
||||||
|
|
||||||
|
stream
|
||||||
} else { // continue with plain stream
|
} else { // continue with plain stream
|
||||||
Either::Right(future::ok(EitherOutput::Second(stream)))
|
EitherOutput::Second(stream)
|
||||||
}
|
};
|
||||||
})
|
|
||||||
.and_then(move |stream| {
|
trace!("receiving websocket handshake request from {}", remote2);
|
||||||
trace!("receiving websocket handshake request from {}", remote2);
|
|
||||||
let mut s = handshake::Server::new();
|
let mut server = handshake::Server::new(stream);
|
||||||
if use_deflate {
|
|
||||||
s.add_extension(Box::new(Deflate::new(Mode::Server)));
|
if use_deflate {
|
||||||
}
|
server.add_extension(Box::new(Deflate::new(connection::Mode::Server)));
|
||||||
Framed::new(stream, s)
|
}
|
||||||
.into_future()
|
|
||||||
.map_err(|(e, _framed)| Error::Handshake(Box::new(e)))
|
let ws_key = {
|
||||||
.and_then(move |(request, framed)| {
|
let request = server.receive_request()
|
||||||
if let Some(r) = request {
|
.map_err(|e| Error::Handshake(Box::new(e)))
|
||||||
trace!("accepting websocket handshake request from {}", remote2);
|
.await?;
|
||||||
let key = Vec::from(r.key());
|
request.into_key()
|
||||||
Either::Left(framed.send(Ok(handshake::Accept::new(key)))
|
};
|
||||||
.map_err(|e| Error::Base(Box::new(e)))
|
|
||||||
.map(move |f| {
|
trace!("accepting websocket handshake request from {}", remote2);
|
||||||
trace!("websocket handshake with {} successful", remote2);
|
|
||||||
let (mut handshake, mut c) =
|
let response =
|
||||||
new_connection(f, max_size, Mode::Server);
|
handshake::server::Response::Accept {
|
||||||
c.add_extensions(handshake.drain_extensions());
|
key: &ws_key,
|
||||||
BytesConnection { inner: c }
|
protocol: None
|
||||||
}))
|
};
|
||||||
} else {
|
|
||||||
debug!("connection to {} terminated during handshake", remote2);
|
server.send_response(&response)
|
||||||
let e: io::Error = io::ErrorKind::ConnectionAborted.into();
|
.map_err(|e| Error::Handshake(Box::new(e)))
|
||||||
Either::Right(future::err(Error::Handshake(Box::new(e))))
|
.await?;
|
||||||
}
|
|
||||||
})
|
let mut conn = server.into_connection();
|
||||||
});
|
conn.set_max_message_size(max_size);
|
||||||
|
conn.set_max_frame_size(max_size);
|
||||||
|
|
||||||
|
Ok(BytesConnection(conn))
|
||||||
|
};
|
||||||
|
|
||||||
ListenerEvent::Upgrade {
|
ListenerEvent::Upgrade {
|
||||||
upgrade: Box::new(upgraded) as Box<dyn Future<Item = _, Error = _> + Send>,
|
upgrade: Box::pin(upgrade) as Pin<Box<dyn Future<Output = _> + Send>>,
|
||||||
local_addr,
|
local_addr,
|
||||||
remote_addr
|
remote_addr
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
Ok(Box::pin(listen) as Box<_>)
|
Ok(Box::pin(listen))
|
||||||
}
|
}
|
||||||
|
|
||||||
fn dial(self, addr: Multiaddr) -> Result<Self::Dial, TransportError<Self::Error>> {
|
fn dial(self, addr: Multiaddr) -> Result<Self::Dial, TransportError<Self::Error>> {
|
||||||
@ -221,121 +233,110 @@ where
|
|||||||
debug!("{} is not a websocket multiaddr", addr);
|
debug!("{} is not a websocket multiaddr", addr);
|
||||||
return Err(TransportError::MultiaddrNotSupported(addr))
|
return Err(TransportError::MultiaddrNotSupported(addr))
|
||||||
}
|
}
|
||||||
|
|
||||||
// We are looping here in order to follow redirects (if any):
|
// We are looping here in order to follow redirects (if any):
|
||||||
let max_redirects = self.max_redirects;
|
let mut remaining_redirects = self.max_redirects;
|
||||||
let future = future::loop_fn((addr, self, max_redirects), |(addr, cfg, remaining)| {
|
let mut addr = addr;
|
||||||
dial(addr, cfg.clone()).and_then(move |result| match result {
|
let future = async move {
|
||||||
Either::Left(redirect) => {
|
loop {
|
||||||
if remaining == 0 {
|
let this = self.clone();
|
||||||
debug!("too many redirects");
|
match this.dial_once(addr).await {
|
||||||
return Err(Error::TooManyRedirects)
|
Ok(Either::Left(redirect)) => {
|
||||||
|
if remaining_redirects == 0 {
|
||||||
|
debug!("too many redirects");
|
||||||
|
return Err(Error::TooManyRedirects)
|
||||||
|
}
|
||||||
|
remaining_redirects -= 1;
|
||||||
|
addr = location_to_multiaddr(&redirect)?
|
||||||
}
|
}
|
||||||
let a = location_to_multiaddr(redirect.location())?;
|
Ok(Either::Right(conn)) => return Ok(conn),
|
||||||
Ok(Loop::Continue((a, cfg, remaining - 1)))
|
Err(e) => return Err(e)
|
||||||
}
|
}
|
||||||
Either::Right(conn) => Ok(Loop::Break(conn))
|
}
|
||||||
})
|
};
|
||||||
});
|
|
||||||
Ok(Box::pin(future) as Box<_>)
|
Ok(Box::pin(future))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Attempty to dial the given address and perform a websocket handshake.
|
impl<T> WsConfig<T>
|
||||||
fn dial<T>(address: Multiaddr, config: WsConfig<T>)
|
|
||||||
-> impl Future<Output = Result<Either<Redirect, BytesConnection<T::Output>>, Error<T::Error>>>
|
|
||||||
where
|
where
|
||||||
T: Transport,
|
T: Transport,
|
||||||
T::Output: AsyncRead + AsyncWrite
|
T::Output: AsyncRead + AsyncWrite + Unpin + 'static
|
||||||
{
|
{
|
||||||
trace!("dial address: {}", address);
|
/// Attempty to dial the given address and perform a websocket handshake.
|
||||||
|
async fn dial_once(self, address: Multiaddr) -> Result<Either<String, BytesConnection<T::Output>>, Error<T::Error>> {
|
||||||
|
trace!("dial address: {}", address);
|
||||||
|
|
||||||
let WsConfig { transport, max_data_size, tls_config, .. } = config;
|
let (host_port, dns_name) = host_and_dnsname(&address)?;
|
||||||
|
|
||||||
let (host_port, dns_name) = match host_and_dnsname(&address) {
|
let mut inner_addr = address.clone();
|
||||||
Ok(x) => x,
|
|
||||||
Err(e) => return Either::Left(future::err(e))
|
|
||||||
};
|
|
||||||
|
|
||||||
let mut inner_addr = address.clone();
|
let (use_tls, path) =
|
||||||
|
match inner_addr.pop() {
|
||||||
|
Some(Protocol::Ws(path)) => (false, path),
|
||||||
|
Some(Protocol::Wss(path)) => {
|
||||||
|
if dns_name.is_none() {
|
||||||
|
debug!("no DNS name in {}", address);
|
||||||
|
return Err(Error::InvalidMultiaddr(address))
|
||||||
|
}
|
||||||
|
(true, path)
|
||||||
|
}
|
||||||
|
_ => {
|
||||||
|
debug!("{} is not a websocket multiaddr", address);
|
||||||
|
return Err(Error::InvalidMultiaddr(address))
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
let (use_tls, path) = match inner_addr.pop() {
|
let dial = self.transport.dial(inner_addr)
|
||||||
Some(Protocol::Ws(path)) => (false, path),
|
.map_err(|e| match e {
|
||||||
Some(Protocol::Wss(path)) => {
|
TransportError::MultiaddrNotSupported(a) => Error::InvalidMultiaddr(a),
|
||||||
if dns_name.is_none() {
|
TransportError::Other(e) => Error::Transport(e)
|
||||||
debug!("no DNS name in {}", address);
|
})?;
|
||||||
return Either::Left(future::err(Error::InvalidMultiaddr(address)))
|
|
||||||
}
|
|
||||||
(true, path)
|
|
||||||
}
|
|
||||||
_ => {
|
|
||||||
debug!("{} is not a websocket multiaddr", address);
|
|
||||||
return Either::Left(future::err(Error::InvalidMultiaddr(address)))
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
let dial = match transport.dial(inner_addr) {
|
let stream = dial.map_err(Error::Transport).await?;
|
||||||
Ok(dial) => dial,
|
trace!("connected to {}", address);
|
||||||
Err(TransportError::MultiaddrNotSupported(a)) =>
|
|
||||||
return Either::Left(future::err(Error::InvalidMultiaddr(a))),
|
|
||||||
Err(TransportError::Other(e)) =>
|
|
||||||
return Either::Left(future::err(Error::Transport(e)))
|
|
||||||
};
|
|
||||||
|
|
||||||
let address1 = address.clone(); // used for logging
|
let stream =
|
||||||
let address2 = address.clone(); // used for logging
|
|
||||||
let use_deflate = config.use_deflate;
|
|
||||||
let future = dial.map_err(Error::Transport)
|
|
||||||
.and_then(move |stream| {
|
|
||||||
trace!("connected to {}", address);
|
|
||||||
if use_tls { // begin TLS session
|
if use_tls { // begin TLS session
|
||||||
let dns_name = dns_name.expect("for use_tls we have checked that dns_name is some");
|
let dns_name = dns_name.expect("for use_tls we have checked that dns_name is some");
|
||||||
trace!("starting TLS handshake with {}", address);
|
trace!("starting TLS handshake with {}", address);
|
||||||
let future = tls_config.client.connect(dns_name.as_ref(), stream)
|
let stream = self.tls_config.client.connect(dns_name.as_ref(), stream)
|
||||||
.map_err(move |e| {
|
.map_err(|e| {
|
||||||
debug!("TLS handshake with {} failed: {}", address, e);
|
debug!("TLS handshake with {} failed: {}", address, e);
|
||||||
Error::Tls(tls::Error::from(e))
|
Error::Tls(tls::Error::from(e))
|
||||||
})
|
})
|
||||||
.map(|s| EitherOutput::First(EitherOutput::First(s)));
|
.await?;
|
||||||
return Either::Left(future)
|
|
||||||
}
|
|
||||||
// continue with plain stream
|
|
||||||
Either::Right(future::ok(EitherOutput::Second(stream)))
|
|
||||||
})
|
|
||||||
.and_then(move |stream| {
|
|
||||||
trace!("sending websocket handshake request to {}", address1);
|
|
||||||
let mut client = handshake::Client::new(host_port, path);
|
|
||||||
if use_deflate {
|
|
||||||
client.add_extension(Box::new(Deflate::new(Mode::Client)));
|
|
||||||
}
|
|
||||||
Framed::new(stream, client)
|
|
||||||
.send(())
|
|
||||||
.map_err(|e| Error::Handshake(Box::new(e)))
|
|
||||||
.and_then(move |framed| {
|
|
||||||
trace!("awaiting websocket handshake response form {}", address2);
|
|
||||||
framed.into_future().map_err(|(e, _)| Error::Base(Box::new(e)))
|
|
||||||
})
|
|
||||||
.and_then(move |(response, framed)| {
|
|
||||||
match response {
|
|
||||||
None => {
|
|
||||||
debug!("connection to {} terminated during handshake", address1);
|
|
||||||
let e: io::Error = io::ErrorKind::ConnectionAborted.into();
|
|
||||||
return Err(Error::Handshake(Box::new(e)))
|
|
||||||
}
|
|
||||||
Some(Response::Redirect(r)) => {
|
|
||||||
debug!("received {}", r);
|
|
||||||
return Ok(Either::Left(r))
|
|
||||||
}
|
|
||||||
Some(Response::Accepted(_)) => {
|
|
||||||
trace!("websocket handshake with {} successful", address1)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
let (mut handshake, mut c) = new_connection(framed, max_data_size, Mode::Client);
|
|
||||||
c.add_extensions(handshake.drain_extensions());
|
|
||||||
Ok(Either::Right(BytesConnection { inner: c }))
|
|
||||||
})
|
|
||||||
});
|
|
||||||
|
|
||||||
Either::Right(future)
|
let stream: TlsOrPlain<_> = EitherOutput::First(EitherOutput::First(stream));
|
||||||
|
stream
|
||||||
|
} else { // continue with plain stream
|
||||||
|
EitherOutput::Second(stream)
|
||||||
|
};
|
||||||
|
|
||||||
|
trace!("sending websocket handshake request to {}", address);
|
||||||
|
|
||||||
|
let mut client = handshake::Client::new(stream, &host_port, path.as_ref());
|
||||||
|
|
||||||
|
if self.use_deflate {
|
||||||
|
client.add_extension(Box::new(Deflate::new(connection::Mode::Client)));
|
||||||
|
}
|
||||||
|
|
||||||
|
match client.handshake().map_err(|e| Error::Handshake(Box::new(e))).await? {
|
||||||
|
handshake::ServerResponse::Redirect { status_code, location } => {
|
||||||
|
debug!("received redirect ({}); location: {}", status_code, location);
|
||||||
|
Ok(Either::Left(location))
|
||||||
|
}
|
||||||
|
handshake::ServerResponse::Rejected { status_code } => {
|
||||||
|
let msg = format!("server rejected handshake; status code = {}", status_code);
|
||||||
|
Err(Error::Handshake(msg.into()))
|
||||||
|
}
|
||||||
|
handshake::ServerResponse::Accepted { .. } => {
|
||||||
|
trace!("websocket handshake with {} successful", address);
|
||||||
|
Ok(Either::Right(BytesConnection(client.into_connection())))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Extract host, port and optionally the DNS name from the given [`Multiaddr`].
|
// Extract host, port and optionally the DNS name from the given [`Multiaddr`].
|
||||||
@ -395,61 +396,50 @@ fn location_to_multiaddr<T>(location: &str) -> Result<Multiaddr, Error<T>> {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Create a `Connection` from an existing `Framed` value.
|
|
||||||
fn new_connection<T, C>(framed: Framed<T, C>, max_size: u64, mode: Mode) -> (C, Connection<T>)
|
|
||||||
where
|
|
||||||
T: AsyncRead + AsyncWrite
|
|
||||||
{
|
|
||||||
let mut codec = base::Codec::new();
|
|
||||||
codec.set_max_data_size(max_size);
|
|
||||||
let old = framed.into_parts();
|
|
||||||
let mut new = FramedParts::new(old.io, codec);
|
|
||||||
new.read_buf = old.read_buf;
|
|
||||||
new.write_buf = old.write_buf;
|
|
||||||
let framed = Framed::from_parts(new);
|
|
||||||
let mut conn = Connection::from_framed(framed, mode);
|
|
||||||
conn.set_max_buffer_size(usize::try_from(max_size).unwrap_or(std::usize::MAX));
|
|
||||||
(old.codec, conn)
|
|
||||||
}
|
|
||||||
|
|
||||||
// BytesConnection ////////////////////////////////////////////////////////////////////////////////
|
// BytesConnection ////////////////////////////////////////////////////////////////////////////////
|
||||||
|
|
||||||
/// A [`Stream`] and [`Sink`] that produces and consumes [`BytesMut`] values
|
/// A [`Stream`] and [`Sink`] that produces and consumes [`BytesMut`] values
|
||||||
/// which correspond to the payload data of websocket frames.
|
/// which correspond to the payload data of websocket frames.
|
||||||
#[derive(Debug)]
|
#[derive(Debug)]
|
||||||
pub struct BytesConnection<T> {
|
pub struct BytesConnection<T>(Connection<TlsOrPlain<T>>);
|
||||||
inner: Connection<EitherOutput<EitherOutput<client::TlsStream<T>, server::TlsStream<T>>, T>>
|
|
||||||
}
|
|
||||||
|
|
||||||
impl<T: AsyncRead + AsyncWrite> Stream for BytesConnection<T> {
|
impl<T: AsyncRead + AsyncWrite + Unpin> Stream for BytesConnection<T> {
|
||||||
type Item = Result<BytesMut, io::Error>;
|
type Item = io::Result<BytesMut>;
|
||||||
|
|
||||||
fn poll_next(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Option<Self::Item>> {
|
fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Option<Self::Item>> {
|
||||||
let data = ready!(self.inner.poll(cx).map_err(|e| io::Error::new(io::ErrorKind::Other, e)));
|
let next = Pin::new(&mut self.0)
|
||||||
Poll::Ready(data.map(base::Data::into_bytes))
|
.poll_next(cx)
|
||||||
|
.map(|item| {
|
||||||
|
item.map(|result| result.map_err(|e| io::Error::new(io::ErrorKind::Other, e)))
|
||||||
|
});
|
||||||
|
Poll::Ready(ready!(next).map(|result| result.map(connection::Data::into)))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<T: AsyncRead + AsyncWrite> Sink<BytesMut> for BytesConnection<T> {
|
impl<T: AsyncRead + AsyncWrite + Unpin> Sink<BytesMut> for BytesConnection<T> {
|
||||||
type Error = io::Error;
|
type Error = io::Error;
|
||||||
|
|
||||||
fn poll_ready(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
|
fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<io::Result<()>> {
|
||||||
Sink::poll_ready(Pin::new(&mut self.inner), cx)
|
Pin::new(&mut self.0)
|
||||||
|
.poll_ready(cx)
|
||||||
.map_err(|e| io::Error::new(io::ErrorKind::Other, e))
|
.map_err(|e| io::Error::new(io::ErrorKind::Other, e))
|
||||||
}
|
}
|
||||||
|
|
||||||
fn start_send(self: Pin<&mut Self>, item: BytesMut) -> Result<(), Self::Error> {
|
fn start_send(mut self: Pin<&mut Self>, item: BytesMut) -> io::Result<()> {
|
||||||
self.inner.start_send(base::Data::Binary(item))
|
Pin::new(&mut self.0)
|
||||||
|
.start_send(connection::Data::Binary(item))
|
||||||
.map_err(|e| io::Error::new(io::ErrorKind::Other, e))
|
.map_err(|e| io::Error::new(io::ErrorKind::Other, e))
|
||||||
}
|
}
|
||||||
|
|
||||||
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
|
fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<io::Result<()>> {
|
||||||
Sink::poll_flush(Pin::new(&mut self.inner), cx)
|
Pin::new(&mut self.0)
|
||||||
|
.poll_flush(cx)
|
||||||
.map_err(|e| io::Error::new(io::ErrorKind::Other, e))
|
.map_err(|e| io::Error::new(io::ErrorKind::Other, e))
|
||||||
}
|
}
|
||||||
|
|
||||||
fn poll_close(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
|
fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<io::Result<()>> {
|
||||||
Sink::poll_close(Pin::new(&mut self.inner), cx)
|
Pin::new(&mut self.0)
|
||||||
|
.poll_close(cx)
|
||||||
.map_err(|e| io::Error::new(io::ErrorKind::Other, e))
|
.map_err(|e| io::Error::new(io::ErrorKind::Other, e))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -60,12 +60,12 @@ impl<T> WsConfig<T> {
|
|||||||
}
|
}
|
||||||
|
|
||||||
/// Get the max. frame data size we support.
|
/// Get the max. frame data size we support.
|
||||||
pub fn max_data_size(&self) -> u64 {
|
pub fn max_data_size(&self) -> usize {
|
||||||
self.transport.max_data_size()
|
self.transport.max_data_size()
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Set the max. frame data size we support.
|
/// Set the max. frame data size we support.
|
||||||
pub fn set_max_data_size(&mut self, size: u64) -> &mut Self {
|
pub fn set_max_data_size(&mut self, size: usize) -> &mut Self {
|
||||||
self.transport.set_max_data_size(size);
|
self.transport.set_max_data_size(size);
|
||||||
self
|
self
|
||||||
}
|
}
|
||||||
@ -96,9 +96,9 @@ where
|
|||||||
T: Transport + Send + Clone + 'static,
|
T: Transport + Send + Clone + 'static,
|
||||||
T::Error: Send + 'static,
|
T::Error: Send + 'static,
|
||||||
T::Dial: Send + 'static,
|
T::Dial: Send + 'static,
|
||||||
T::Listener: Send + 'static,
|
T::Listener: Send + Unpin + 'static,
|
||||||
T::ListenerUpgrade: Send + 'static,
|
T::ListenerUpgrade: Send + 'static,
|
||||||
T::Output: AsyncRead + AsyncWrite + Send + 'static
|
T::Output: AsyncRead + AsyncWrite + Unpin + Send + 'static
|
||||||
{
|
{
|
||||||
type Output = RwStreamSink<BytesConnection<T::Output>>;
|
type Output = RwStreamSink<BytesConnection<T::Output>>;
|
||||||
type Error = Error<T::Error>;
|
type Error = Error<T::Error>;
|
||||||
|
@ -19,7 +19,7 @@
|
|||||||
// DEALINGS IN THE SOFTWARE.
|
// DEALINGS IN THE SOFTWARE.
|
||||||
|
|
||||||
use std::{fmt, io, sync::Arc};
|
use std::{fmt, io, sync::Arc};
|
||||||
use tokio_rustls::{
|
use futures_rustls::{
|
||||||
TlsConnector,
|
TlsConnector,
|
||||||
TlsAcceptor,
|
TlsAcceptor,
|
||||||
rustls,
|
rustls,
|
||||||
|
Reference in New Issue
Block a user